Skip to content

Commit 958711d

Browse files
simonbyrnejrevels
authored andcommitted
reduce occurences of tag errors (#389)
1 parent 52f4be7 commit 958711d

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/dual.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,12 @@ end
8282

8383
@inline value(::Type{T}, x) where T = x
8484
@inline value(::Type{T}, d::Dual{T}) where T = value(d)
85-
function value(::Type{T}, d::Dual{S}) where {T,S}
86-
# TODO: in the case of nested Duals, it may be possible to "transpose" the Dual objects
87-
throw(DualMismatchError(T,S))
85+
@inline function value(::Type{T}, d::Dual{S}) where {T,S}
86+
if S T
87+
d
88+
else
89+
throw(DualMismatchError(T,S))
90+
end
8891
end
8992

9093
@inline partials(x) = Partials{0,typeof(x)}(tuple())
@@ -96,7 +99,14 @@ end
9699

97100
@inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...)
98101
@inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...)
99-
partials(::Type{T}, d::Dual{S}, i...) where {T,S} = throw(DualMismatchError(T,S))
102+
@inline function partials(::Type{T}, d::Dual{S}, i...) where {T,S}
103+
if S T
104+
zero(d)
105+
else
106+
throw(DualMismatchError(T,S))
107+
end
108+
end
109+
100110

101111
@inline npartials(::Dual{T,V,N}) where {T,V,N} = N
102112
@inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N

test/ConfusionTest.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,11 @@ let z = z83a
6666
@test ForwardDiff.hessian(h, [1.]) == zeros(1, 1)
6767
end
6868

69+
@test ForwardDiff.derivative(1.0) do x
70+
ForwardDiff.derivative(x) do y
71+
x
72+
end
73+
end == 0.0
74+
75+
6976
end # module

0 commit comments

Comments
 (0)