@@ -342,30 +342,88 @@ significantly more expensive than `x*y+z`. `fma` is used to improve accuracy in
342
342
algorithms. See [`muladd`](@ref).
343
343
"""
344
344
function fma end
345
+ function fma_emulated (a:: Float32 , b:: Float32 , c:: Float32 ):: Float32
346
+ ab = Float64 (a) * b
347
+ res = ab+ c
348
+ reinterpret (UInt64, res)& 0x1fff_ffff != 0x1000_0000 && return res
349
+ # yes error compensation is necessary. It sucks
350
+ reslo = abs (c)> abs (ab) ? ab- (res - c) : c- (res - ab)
351
+ res = iszero (reslo) ? res : (signbit (reslo) ? prevfloat (res) : nextfloat (res))
352
+ return res
353
+ end
354
+
355
+ """ Splits a Float64 into a hi bit and a low bit where the high bit has 27 trailing 0s and the low bit has 26 trailing 0s"""
356
+ @inline function splitbits (x:: Float64 )
357
+ hi = reinterpret (Float64, reinterpret (UInt64, x) & 0xffff_ffff_f800_0000 )
358
+ return hi, x- hi
359
+ end
360
+
361
+ @inline function twomul (a:: Float64 , b:: Float64 )
362
+ ahi, alo = splitbits (a)
363
+ bhi, blo = splitbits (b)
364
+ abhi = a* b
365
+ blohi, blolo = splitbits (blo)
366
+ ablo = alo* blohi - (((abhi - ahi* bhi) - alo* bhi) - ahi* blo) + blolo* alo
367
+ return abhi, ablo
368
+ end
345
369
346
- fma_libm (x:: Float32 , y:: Float32 , z:: Float32 ) =
347
- ccall ((" fmaf" , libm_name), Float32, (Float32,Float32,Float32), x, y, z)
348
- fma_libm (x:: Float64 , y:: Float64 , z:: Float64 ) =
349
- ccall ((" fma" , libm_name), Float64, (Float64,Float64,Float64), x, y, z)
370
+ function fma_emulated (a:: Float64 , b:: Float64 ,c:: Float64 )
371
+ abhi, ablo = twomul (a,b)
372
+ if ! isfinite (abhi+ c) || isless (abs (abhi), nextfloat (0x1 p- 969 )) || issubnormal (a) || issubnormal (b)
373
+ (isfinite (a) && isfinite (b) && isfinite (c)) || return abhi+ c
374
+ (iszero (a) || iszero (b)) && return abhi+ c
375
+ bias = exponent (a) + exponent (b)
376
+ c_denorm = ldexp (c, - bias)
377
+ if isfinite (c_denorm)
378
+ # rescale a and b to [1,2), equivalent to ldexp(a, -exponent(a))
379
+ issubnormal (a) && (a *= 0x1 p52)
380
+ issubnormal (b) && (b *= 0x1 p52)
381
+ a = reinterpret (Float64, (reinterpret (UInt64, a) & 0x800fffffffffffff ) | 0x3ff0000000000000 )
382
+ b = reinterpret (Float64, (reinterpret (UInt64, b) & 0x800fffffffffffff ) | 0x3ff0000000000000 )
383
+ c = c_denorm
384
+ abhi, ablo = twomul (a,b)
385
+ r = abhi+ c
386
+ s = (abs (abhi) > abs (c)) ? (abhi- r+ c+ ablo) : (c- r+ abhi+ ablo)
387
+ sumhi = r+ s
388
+ # If result is subnormal, ldexp will cause double rounding because subnormals have fewer mantisa bits.
389
+ # As such, we need to check whether round to even would lead to double rounding and manually round sumhi to avoid it.
390
+ if issubnormal (ldexp (sumhi, bias))
391
+ sumlo = r- sumhi+ s
392
+ bits_lost = - bias- exponent (sumhi)- 1022
393
+ sumhiInt = reinterpret (UInt64, sumhi)
394
+ if (bits_lost != 1 ) ⊻ (sumhiInt& 1 == 1 )
395
+ sumhi = nextfloat (sumhi, cmp (sumlo,0 ))
396
+ end
397
+ end
398
+ return ldexp (sumhi, bias)
399
+ end
400
+ isinf (abhi) && signbit (c) == signbit (a* b) && return abhi
401
+ # fall through
402
+ end
403
+ r = abhi+ c
404
+ s = (abs (abhi) > abs (c)) ? (abhi- r+ c+ ablo) : (c- r+ abhi+ ablo)
405
+ return r+ s
406
+ end
350
407
fma_llvm (x:: Float32 , y:: Float32 , z:: Float32 ) = fma_float (x, y, z)
351
408
fma_llvm (x:: Float64 , y:: Float64 , z:: Float64 ) = fma_float (x, y, z)
352
409
# Disable LLVM's fma if it is incorrect, e.g. because LLVM falls back
353
- # onto a broken system libm; if so, use openlibm's fma instead
410
+ # onto a broken system libm; if so, use a software emulated fma
354
411
# 1.0000305f0 = 1 + 1/2^15
355
412
# 1.0000000009313226 = 1 + 1/2^30
356
413
# If fma_llvm() clobbers the rounding mode, the result of 0.1 + 0.2 will be 0.3
357
414
# instead of the properly-rounded 0.30000000000000004; check after calling fma
415
+ # TODO actually detect fma in hardware and switch on that.
358
416
if (Sys. ARCH != = :i686 && fma_llvm (1.0000305f0 , 1.0000305f0 , - 1.0f0 ) == 6.103609f-5 &&
359
417
(fma_llvm (1.0000000009313226 , 1.0000000009313226 , - 1.0 ) ==
360
418
1.8626451500983188e-9 ) && 0.1 + 0.2 == 0.30000000000000004 )
361
419
fma (x:: Float32 , y:: Float32 , z:: Float32 ) = fma_llvm (x,y,z)
362
420
fma (x:: Float64 , y:: Float64 , z:: Float64 ) = fma_llvm (x,y,z)
363
421
else
364
- fma (x:: Float32 , y:: Float32 , z:: Float32 ) = fma_libm (x,y,z)
365
- fma (x:: Float64 , y:: Float64 , z:: Float64 ) = fma_libm (x,y,z)
422
+ fma (x:: Float32 , y:: Float32 , z:: Float32 ) = fma_emulated (x,y,z)
423
+ fma (x:: Float64 , y:: Float64 , z:: Float64 ) = fma_emulated (x,y,z)
366
424
end
367
425
function fma (a:: Float16 , b:: Float16 , c:: Float16 )
368
- Float16 (fma (Float32 (a), Float32 (b), Float32 (c)))
426
+ Float16 (muladd (Float32 (a), Float32 (b), Float32 (c))) # don't use fma if the hardware doesn't have it.
369
427
end
370
428
371
429
# This is necessary at least on 32-bit Intel Linux, since fma_llvm may
0 commit comments