Skip to content

Commit fdc5a72

Browse files
committed
Revert "Merge pull request #569 from JuliaDiff/kc/revert_simd"
This reverts commit 43ef860, reversing changes made to 01a056d.
1 parent 9c2be97 commit fdc5a72

File tree

6 files changed

+56
-23
lines changed

6 files changed

+56
-23
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ jobs:
1616
fail-fast: false
1717
matrix:
1818
version:
19-
- '1.0'
2019
- '1'
2120
- 'nightly'
2221
os:

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1212
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1313
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1516
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1617
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718

@@ -24,9 +25,11 @@ DiffTests = "0.0.1, 0.1"
2425
LogExpFunctions = "0.3"
2526
NaNMath = "0.2.2, 0.3"
2627
Preferences = "1"
28+
SIMD = "3"
2729
SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0, 2"
30+
SIMD = "3"
2831
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
29-
julia = "1"
32+
julia = "1.6"
3033

3134
[extras]
3235
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"

src/ForwardDiff.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,21 @@ if VERSION >= v"1.6"
88
end
99
using Random
1010
using LinearAlgebra
11+
import SIMD: Vec
1112

1213
import Printf
1314
import NaNMath
1415
import SpecialFunctions
1516
import LogExpFunctions
1617
import CommonSubexpressions
1718

19+
const SIMDFloat = Union{Float64, Float32}
20+
const SIMDInt = Union{
21+
Int128, Int64, Int32, Int16, Int8,
22+
UInt128, UInt64, UInt32, UInt16, UInt8,
23+
}
24+
const SIMDType = Union{SIMDFloat, SIMDInt}
25+
1826
include("prelude.jl")
1927
include("partials.jl")
2028
include("dual.jl")

src/dual.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,16 @@ end
541541
# fma #
542542
#-----#
543543

544+
@inline function calc_fma_xyz(x::Dual{T,V,N},
545+
y::Dual{T,V,N},
546+
z::Dual{T,V,N}) where {T, V<:SIMDFloat,N}
547+
xv, yv, zv = value(x), value(y), value(z)
548+
rv = fma(xv, yv, zv)
549+
N == 0 && return Dual{T}(rv)
550+
xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values)
551+
parts = Tuple(fma(xv, yp, fma(yv, xp, zp)))
552+
Dual{T}(rv, parts)
553+
end
544554
@generated function calc_fma_xyz(x::Dual{T,<:Any,N},
545555
y::Dual{T,<:Any,N},
546556
z::Dual{T,<:Any,N}) where {T,N}
@@ -583,6 +593,16 @@ end
583593
# muladd #
584594
#--------#
585595

596+
@inline function calc_muladd_xyz(x::Dual{T,V,N},
597+
y::Dual{T,V,N},
598+
z::Dual{T,V,N}) where {T, V<:SIMDType,N}
599+
xv, yv, zv = value(x), value(y), value(z)
600+
rv = muladd(xv, yv, zv)
601+
N == 0 && return Dual{T}(rv)
602+
xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values)
603+
parts = Tuple(muladd(xv, yp, muladd(yv, xp, zp)))
604+
Dual{T}(rv, parts)
605+
end
586606
@generated function calc_muladd_xyz(x::Dual{T,<:Any,N},
587607
y::Dual{T,<:Any,N},
588608
z::Dual{T,<:Any,N}) where {T,N}

src/partials.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ end
141141
@inline _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B} = bfactor * b
142142
@inline _mul_partials(a::Partials{N,A}, b::Partials{0,B}, afactor, bfactor) where {N,A,B} = afactor * a
143143

144+
const SIMDFloat = Union{Float64, Float32}
145+
const SIMDInt = Union{
146+
Int128, Int64, Int32, Int16, Int8,
147+
UInt128, UInt64, UInt32, UInt16, UInt8,
148+
}
149+
const SIMDType = Union{SIMDFloat, SIMDInt}
150+
144151
##################################
145152
# Generated Functions on NTuples #
146153
##################################
@@ -164,6 +171,7 @@ end
164171
@inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple()
165172
@inline rand_tuple(::Type{Tuple{}}) = tuple()
166173

174+
iszero_tuple(tup::NTuple{N,V}) where {N, V<:SIMDType} = sum(Vec(tup) != zero(V)) == 0
167175
@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V}
168176
ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...)
169177
return quote
@@ -197,29 +205,24 @@ end
197205
return tupexpr(i -> :(rand(V)), N)
198206
end
199207

200-
@generated function scale_tuple(tup::NTuple{N}, x) where N
201-
return tupexpr(i -> :(tup[$i] * x), N)
202-
end
208+
const NT{N,T} = NTuple{N,T}
203209

204-
@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
205-
return tupexpr(i -> :(tup[$i] / x), N)
206-
end
210+
# SIMD implementation
211+
@inline add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) + Vec(b))
212+
@inline sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) - Vec(b))
213+
@inline scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType} = Tuple(Vec(tup) * x)
214+
@inline div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat} = Tuple(Vec(tup) / x)
215+
@inline minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType} = Tuple(-Vec(tup))
216+
@inline mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple(muladd(af, Vec(a), bf * Vec(b)))
207217

208-
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
209-
return tupexpr(i -> :(a[$i] + b[$i]), N)
210-
end
211218

212-
@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
213-
return tupexpr(i -> :(a[$i] - b[$i]), N)
214-
end
215-
216-
@generated function minus_tuple(tup::NTuple{N}) where N
217-
return tupexpr(i -> :(-tup[$i]), N)
218-
end
219-
220-
@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
221-
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
222-
end
219+
# Fallback implementations
220+
@generated add_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] + b[$i]), N)
221+
@generated sub_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] - b[$i]), N)
222+
@generated scale_tuple(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] * x), N)
223+
@generated div_tuple_by_scalar(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] / x), N)
224+
@generated minus_tuple(tup::NT{N}) where N = tupexpr(i -> :(-tup[$i]), N)
225+
@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :(muladd(af, a[$i], bf * b[$i])), N)
223226

224227
###################
225228
# Pretty Printing #

test/PartialsTest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ for N in (0, 3), T in (Int, Float32, Float64)
114114

115115
if N > 0
116116
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
117-
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
117+
@test all(isapprox.(ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values, map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)))
118118
@test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
119119
@test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS
120120

0 commit comments

Comments
 (0)