Skip to content

Commit 7d0b12c

Browse files
Turn StaticArrays into an Extension Pacakge
omg so fast! no requires on v1.8 extension typo fix imports fix all tests fix the backwards compat with no requires add missing ones tag overload valtype add partials Update ext/ForwardDiffStaticArraysExt.jl Co-authored-by: David Widmann <[email protected]> a few more imports finish the migration
1 parent eb8d755 commit 7d0b12c

File tree

8 files changed

+146
-124
lines changed

8 files changed

+146
-124
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,19 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2"
2828
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
2929
julia = "1"
3030

31+
[extensions]
32+
ForwardDiffStaticArraysExt = "StaticArrays"
33+
3134
[extras]
3235
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3336
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
3437
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3538
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
39+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3640
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3741

3842
[targets]
39-
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"]
43+
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils", "StaticArrays"]
44+
45+
[weakdeps]
46+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
module ForwardDiffStaticArraysExt
2+
3+
using ForwardDiff
4+
using ForwardDiff.DiffResults: DiffResults, DiffResult, ImmutableDiffResult, MutableDiffResult
5+
using ForwardDiff.LinearAlgebra
6+
import ForwardDiff: Dual, Chunk, value, extract_jacobian!, extract_value!, extract_gradient!, extract_jacobian!,
7+
GradientConfig, JacobianConfig, HessianConfig, vector_mode_gradient, vector_mode_gradient!,
8+
Tag, valtype, partials, gradient, gradient!, jacobian, jacobian!, hessian, hessian!, vector_mode_jacobian,
9+
vector_mode_jacobian!
10+
using StaticArrays
11+
12+
@generated function dualize(::Type{T}, x::StaticArray) where T
13+
N = length(x)
14+
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
15+
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N})
16+
return quote
17+
chunk = Chunk{$N}()
18+
$(Expr(:meta, :inline))
19+
return $V($(dx))
20+
end
21+
end
22+
23+
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
24+
25+
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
26+
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
27+
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
28+
29+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
30+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
31+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
32+
33+
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
34+
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
35+
return quote
36+
$(Expr(:meta, :inline))
37+
V = StaticArrays.similar_type(S, valtype($y))
38+
return V($result)
39+
end
40+
end
41+
42+
@inline function ForwardDiff.vector_mode_gradient(f, x::StaticArray)
43+
T = typeof(Tag(f, eltype(x)))
44+
return extract_gradient(T, static_dual_eval(T, f, x), x)
45+
end
46+
47+
@inline function ForwardDiff.vector_mode_gradient!(result, f, x::StaticArray)
48+
T = typeof(Tag(f, eltype(x)))
49+
return extract_gradient!(T, result, static_dual_eval(T, f, x))
50+
end
51+
52+
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
53+
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
54+
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
55+
56+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
57+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
58+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
59+
60+
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
61+
M, N = length(ydual), length(x)
62+
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
63+
return quote
64+
$(Expr(:meta, :inline))
65+
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
66+
return V($result)
67+
end
68+
end
69+
70+
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
71+
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
72+
return extract_jacobian!(T, result, ydual, length(x))
73+
end
74+
75+
@inline function ForwardDiff.vector_mode_jacobian(f, x::StaticArray)
76+
T = typeof(Tag(f, eltype(x)))
77+
return extract_jacobian(T, static_dual_eval(T, f, x), x)
78+
end
79+
80+
@inline function ForwardDiff.vector_mode_jacobian!(result, f, x::StaticArray)
81+
T = typeof(Tag(f, eltype(x)))
82+
ydual = static_dual_eval(T, f, x)
83+
result = extract_jacobian!(T, result, ydual, length(x))
84+
result = extract_value!(T, result, ydual)
85+
return result
86+
end
87+
88+
@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
89+
T = typeof(Tag(f, eltype(x)))
90+
ydual = static_dual_eval(T, f, x)
91+
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
92+
result = DiffResults.value!(d -> value(T,d), result, ydual)
93+
return result
94+
end
95+
96+
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
97+
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
98+
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
99+
100+
ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
101+
102+
ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
103+
104+
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
105+
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
106+
107+
function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
108+
T = typeof(Tag(f, eltype(x)))
109+
d1 = dualize(T, x)
110+
d2 = dualize(T, d1)
111+
fd2 = f(d2)
112+
val = value(T,value(T,fd2))
113+
grad = extract_gradient(T,value(T,fd2), x)
114+
hess = extract_jacobian(T,partials(T,fd2), x)
115+
result = DiffResults.hessian!(result, hess)
116+
result = DiffResults.gradient!(result, grad)
117+
result = DiffResults.value!(result, val)
118+
return result
119+
end
120+
121+
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
122+
λ,Q = eigen(Symmetric(value.(parent(A))))
123+
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
124+
Dual{Tg}.(λ, tuple.(parts...))
125+
end
126+
127+
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
128+
λ = eigvals(A)
129+
_,Q = eigen(Symmetric(value.(parent(A))))
130+
parts = ntuple(j -> Q*ForwardDiff._lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
131+
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
132+
end
133+
134+
end

src/ForwardDiff.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module ForwardDiff
22

33
using DiffRules, DiffResults
44
using DiffResults: DiffResult, MutableDiffResult, ImmutableDiffResult
5-
using StaticArrays
65
if VERSION >= v"1.6"
76
using Preferences
87
end
@@ -25,6 +24,10 @@ include("gradient.jl")
2524
include("jacobian.jl")
2625
include("hessian.jl")
2726

27+
if !isdefined(Base, :get_extension)
28+
include("../ext/ForwardDiffStaticArraysExt.jl")
29+
end
30+
2831
export DiffResults
2932

3033
end # module

src/apiutils.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,6 @@ end
1818
# vector mode function evaluation #
1919
###################################
2020

21-
@generated function dualize(::Type{T}, x::StaticArray) where T
22-
N = length(x)
23-
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
24-
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N})
25-
return quote
26-
chunk = Chunk{$N}()
27-
$(Expr(:meta, :inline))
28-
return $V($(dx))
29-
end
30-
end
31-
32-
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
33-
3421
function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig}, x) where {F}
3522
xdual = cfg.duals
3623
seed!(xdual, x, cfg.seeds)

src/dual.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -708,12 +708,6 @@ function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N
708708
Dual{Tg}.(λ, tuple.(parts...))
709709
end
710710

711-
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
712-
λ,Q = eigen(Symmetric(value.(parent(A))))
713-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
714-
Dual{Tg}.(λ, tuple.(parts...))
715-
end
716-
717711
function LinearAlgebra.eigvals(A::Hermitian{<:Complex{<:Dual{Tg,T,N}}}) where {Tg,T<:Real,N}
718712
λ,Q = eigen(Hermitian(value.(real.(parent(A))) .+ im .* value.(imag.(parent(A)))))
719713
parts = ntuple(j -> diag(real.(Q' * (getindex.(partials.(real.(A)) .+ im .* partials.(imag.(A)), j)) * Q)), N)
@@ -743,13 +737,6 @@ function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
743737
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
744738
end
745739

746-
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
747-
λ = eigvals(A)
748-
_,Q = eigen(Symmetric(value.(parent(A))))
749-
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
750-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
751-
end
752-
753740
function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
754741
λ = eigvals(A)
755742
_,Q = eigen(SymTridiagonal(value.(parent(A))))

src/gradient.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,12 @@ function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArr
4141
return result
4242
end
4343

44-
@inline gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
45-
@inline gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
46-
@inline gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
47-
48-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
49-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
50-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
51-
5244
gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
5345

5446
#####################
5547
# result extraction #
5648
#####################
5749

58-
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
59-
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
60-
return quote
61-
$(Expr(:meta, :inline))
62-
V = StaticArrays.similar_type(S, valtype($y))
63-
return V($result)
64-
end
65-
end
66-
6750
function extract_gradient!(::Type{T}, result::DiffResult, y::Real) where {T}
6851
result = DiffResults.value!(result, y)
6952
grad = DiffResults.gradient(result)
@@ -115,16 +98,6 @@ function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T
11598
return result
11699
end
117100

118-
@inline function vector_mode_gradient(f, x::StaticArray)
119-
T = typeof(Tag(f, eltype(x)))
120-
return extract_gradient(T, static_dual_eval(T, f, x), x)
121-
end
122-
123-
@inline function vector_mode_gradient!(result, f, x::StaticArray)
124-
T = typeof(Tag(f, eltype(x)))
125-
return extract_gradient!(T, result, static_dual_eval(T, f, x))
126-
end
127-
128101
##############
129102
# chunk mode #
130103
##############

src/hessian.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,3 @@ function hessian!(result::DiffResult, f, x::AbstractArray, cfg::HessianConfig{T}
6767
jacobian!(DiffResults.hessian(result), ∇f!, DiffResults.gradient(result), x, cfg.jacobian_config, Val{false}())
6868
return ∇f!.result
6969
end
70-
71-
hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
72-
hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
73-
hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
74-
75-
hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
76-
77-
hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
78-
79-
hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
80-
hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
81-
82-
function hessian!(result::ImmutableDiffResult, f, x::StaticArray)
83-
T = typeof(Tag(f, eltype(x)))
84-
d1 = dualize(T, x)
85-
d2 = dualize(T, d1)
86-
fd2 = f(d2)
87-
val = value(T,value(T,fd2))
88-
grad = extract_gradient(T,value(T,fd2), x)
89-
hess = extract_jacobian(T,partials(T,fd2), x)
90-
result = DiffResults.hessian!(result, hess)
91-
result = DiffResults.gradient!(result, grad)
92-
result = DiffResults.value!(result, val)
93-
return result
94-
end

src/jacobian.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,35 +82,12 @@ function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray
8282
return result
8383
end
8484

85-
@inline jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
86-
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
87-
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
88-
89-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
90-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
91-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
92-
9385
jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
9486

9587
#####################
9688
# result extraction #
9789
#####################
9890

99-
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
100-
M, N = length(ydual), length(x)
101-
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
102-
return quote
103-
$(Expr(:meta, :inline))
104-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
105-
return V($result)
106-
end
107-
end
108-
109-
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
110-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
111-
return extract_jacobian!(T, result, ydual, length(x))
112-
end
113-
11491
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
11592
out_reshaped = reshape(result, length(ydual), n)
11693
ydual_reshaped = vec(ydual)
@@ -180,27 +157,6 @@ function vector_mode_jacobian!(result, f!::F, y, x, cfg::JacobianConfig{T}) wher
180157
return result
181158
end
182159

183-
@inline function vector_mode_jacobian(f, x::StaticArray)
184-
T = typeof(Tag(f, eltype(x)))
185-
return extract_jacobian(T, static_dual_eval(T, f, x), x)
186-
end
187-
188-
@inline function vector_mode_jacobian!(result, f, x::StaticArray)
189-
T = typeof(Tag(f, eltype(x)))
190-
ydual = static_dual_eval(T, f, x)
191-
result = extract_jacobian!(T, result, ydual, length(x))
192-
result = extract_value!(T, result, ydual)
193-
return result
194-
end
195-
196-
@inline function vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
197-
T = typeof(Tag(f, eltype(x)))
198-
ydual = static_dual_eval(T, f, x)
199-
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
200-
result = DiffResults.value!(d -> value(T,d), result, ydual)
201-
return result
202-
end
203-
204160
const JACOBIAN_ERROR = DimensionMismatch("jacobian(f, x) expects that f(x) is an array. Perhaps you meant gradient(f, x)?")
205161

206162
# chunk mode #

0 commit comments

Comments
 (0)