Skip to content

Commit ee36c13

Browse files
authored
Emulated fma (#42783)
Emulated `fma` for cases when hardware fma is not available. Generally pre-Haswell, some arm, etc. Co-authored-by: oscarddssmith <[email protected]>
1 parent e7df4a6 commit ee36c13

File tree

2 files changed

+90
-8
lines changed

2 files changed

+90
-8
lines changed

base/floatfuncs.jl

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -342,30 +342,88 @@ significantly more expensive than `x*y+z`. `fma` is used to improve accuracy in
342342
algorithms. See [`muladd`](@ref).
343343
"""
344344
function fma end
345+
function fma_emulated(a::Float32, b::Float32, c::Float32)::Float32
346+
ab = Float64(a) * b
347+
res = ab+c
348+
reinterpret(UInt64, res)&0x1fff_ffff!=0x1000_0000 && return res
349+
# yes error compensation is necessary. It sucks
350+
reslo = abs(c)>abs(ab) ? ab-(res - c) : c-(res - ab)
351+
res = iszero(reslo) ? res : (signbit(reslo) ? prevfloat(res) : nextfloat(res))
352+
return res
353+
end
354+
355+
""" Splits a Float64 into a hi bit and a low bit where the high bit has 27 trailing 0s and the low bit has 26 trailing 0s"""
356+
@inline function splitbits(x::Float64)
357+
hi = reinterpret(Float64, reinterpret(UInt64, x) & 0xffff_ffff_f800_0000)
358+
return hi, x-hi
359+
end
360+
361+
@inline function twomul(a::Float64, b::Float64)
362+
ahi, alo = splitbits(a)
363+
bhi, blo = splitbits(b)
364+
abhi = a*b
365+
blohi, blolo = splitbits(blo)
366+
ablo = alo*blohi - (((abhi - ahi*bhi) - alo*bhi) - ahi*blo) + blolo*alo
367+
return abhi, ablo
368+
end
345369

346-
fma_libm(x::Float32, y::Float32, z::Float32) =
347-
ccall(("fmaf", libm_name), Float32, (Float32,Float32,Float32), x, y, z)
348-
fma_libm(x::Float64, y::Float64, z::Float64) =
349-
ccall(("fma", libm_name), Float64, (Float64,Float64,Float64), x, y, z)
370+
function fma_emulated(a::Float64, b::Float64,c::Float64)
371+
abhi, ablo = twomul(a,b)
372+
if !isfinite(abhi+c) || isless(abs(abhi), nextfloat(0x1p-969)) || issubnormal(a) || issubnormal(b)
373+
(isfinite(a) && isfinite(b) && isfinite(c)) || return abhi+c
374+
(iszero(a) || iszero(b)) && return abhi+c
375+
bias = exponent(a) + exponent(b)
376+
c_denorm = ldexp(c, -bias)
377+
if isfinite(c_denorm)
378+
# rescale a and b to [1,2), equivalent to ldexp(a, -exponent(a))
379+
issubnormal(a) && (a *= 0x1p52)
380+
issubnormal(b) && (b *= 0x1p52)
381+
a = reinterpret(Float64, (reinterpret(UInt64, a) & 0x800fffffffffffff) | 0x3ff0000000000000)
382+
b = reinterpret(Float64, (reinterpret(UInt64, b) & 0x800fffffffffffff) | 0x3ff0000000000000)
383+
c = c_denorm
384+
abhi, ablo = twomul(a,b)
385+
r = abhi+c
386+
s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
387+
sumhi = r+s
388+
# If result is subnormal, ldexp will cause double rounding because subnormals have fewer mantisa bits.
389+
# As such, we need to check whether round to even would lead to double rounding and manually round sumhi to avoid it.
390+
if issubnormal(ldexp(sumhi, bias))
391+
sumlo = r-sumhi+s
392+
bits_lost = -bias-exponent(sumhi)-1022
393+
sumhiInt = reinterpret(UInt64, sumhi)
394+
if (bits_lost != 1) (sumhiInt&1 == 1)
395+
sumhi = nextfloat(sumhi, cmp(sumlo,0))
396+
end
397+
end
398+
return ldexp(sumhi, bias)
399+
end
400+
isinf(abhi) && signbit(c) == signbit(a*b) && return abhi
401+
# fall through
402+
end
403+
r = abhi+c
404+
s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
405+
return r+s
406+
end
350407
fma_llvm(x::Float32, y::Float32, z::Float32) = fma_float(x, y, z)
351408
fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)
352409
# Disable LLVM's fma if it is incorrect, e.g. because LLVM falls back
353-
# onto a broken system libm; if so, use openlibm's fma instead
410+
# onto a broken system libm; if so, use a software emulated fma
354411
# 1.0000305f0 = 1 + 1/2^15
355412
# 1.0000000009313226 = 1 + 1/2^30
356413
# If fma_llvm() clobbers the rounding mode, the result of 0.1 + 0.2 will be 0.3
357414
# instead of the properly-rounded 0.30000000000000004; check after calling fma
415+
# TODO actually detect fma in hardware and switch on that.
358416
if (Sys.ARCH !== :i686 && fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
359417
(fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) ==
360418
1.8626451500983188e-9) && 0.1 + 0.2 == 0.30000000000000004)
361419
fma(x::Float32, y::Float32, z::Float32) = fma_llvm(x,y,z)
362420
fma(x::Float64, y::Float64, z::Float64) = fma_llvm(x,y,z)
363421
else
364-
fma(x::Float32, y::Float32, z::Float32) = fma_libm(x,y,z)
365-
fma(x::Float64, y::Float64, z::Float64) = fma_libm(x,y,z)
422+
fma(x::Float32, y::Float32, z::Float32) = fma_emulated(x,y,z)
423+
fma(x::Float64, y::Float64, z::Float64) = fma_emulated(x,y,z)
366424
end
367425
function fma(a::Float16, b::Float16, c::Float16)
368-
Float16(fma(Float32(a), Float32(b), Float32(c)))
426+
Float16(muladd(Float32(a), Float32(b), Float32(c))) #don't use fma if the hardware doesn't have it.
369427
end
370428

371429
# This is necessary at least on 32-bit Intel Linux, since fma_llvm may

test/math.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,3 +1286,27 @@ end
12861286
@test_throws MethodError f(x)
12871287
end
12881288
end
1289+
1290+
@testset "fma" begin
1291+
for func in (fma, Base.fma_emulated)
1292+
@test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
1293+
@test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
1294+
@testset "$T" for T in (Float32, Float64)
1295+
@test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
1296+
@test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
1297+
@test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
1298+
@test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
1299+
@test isnan_type(T, func(T(Inf), T(0), -T(0)))
1300+
@test func(-zero(T), zero(T), -zero(T)) === -zero(T)
1301+
for _ in 1:2^18
1302+
a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
1303+
@test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
1304+
end
1305+
end
1306+
@test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
1307+
@test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
1308+
@test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
1309+
@test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
1310+
@test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
1311+
end
1312+
end

0 commit comments

Comments
 (0)