Skip to content

Commit ae06e73

Browse files
committed
fix some issues with equality of factorizations
- `hash` did not respect the type of a factorization, so completely different factorizations with the same underlying data would result in same `hash` leading to inconsistencies with `isequal`. This likely doesn't occur very often in practice, but definitely seems worth fixing. - `==` and `isequal` only returned true if two factorizations are of exactly the same type, which is inconsistent with their implementation for other objects and with the definition of `hash` for factorizations. - Equality for `QRCompactWY` did not ignore the subdiagonal entries of `T` leading to nondeterministic behavior. Perhaps `T` should be directly stored as `UpperTriangular` in `QRCompactWY`, but that seems potentially breaking. Relying on implementation details of `DataType` here is certainly less than ideal, but I could not come up with a nicer solution.
1 parent 4a81b08 commit ae06e73

File tree

4 files changed

+65
-3
lines changed

4 files changed

+65
-3
lines changed

stdlib/LinearAlgebra/src/factorization.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,17 @@ Factorization{T}(A::Adjoint{<:Any,<:Factorization}) where {T} =
6464
adjoint(Factorization{T}(parent(A)))
6565
inv(F::Factorization{T}) where {T} = (n = size(F, 1); ldiv!(F, Matrix{T}(I, n, n)))
6666

67-
Base.hash(F::Factorization, h::UInt) = mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=h)
68-
Base.:(==)( F::T, G::T) where {T<:Factorization} = all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F))
69-
Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool
67+
function Base.hash(F::Factorization, h::UInt)
68+
return mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=hash(typeof(F).name.wrapper, h))
69+
end
70+
function Base.:(==)(F::Factorization, G::Factorization)
71+
typeof(F).name.wrapper == typeof(G).name.wrapper || return false
72+
return all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F))
73+
end
74+
function Base.isequal(F::Factorization, G::Factorization)
75+
typeof(F).name.wrapper == typeof(G).name.wrapper || return false
76+
return all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool
77+
end
7078

7179
function Base.show(io::IO, x::Adjoint{<:Any,<:Factorization})
7280
print(io, "Adjoint of ")

stdlib/LinearAlgebra/src/qr.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R))
127127
Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done))
128128
Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing
129129

130+
function Base.hash(F::QRCompactWY, h::UInt)
131+
return hash(F.factors, hash(UpperTriangular(F.T), hash(QRCompactWY, h)))
132+
end
133+
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
134+
return A.factors == B.factors && UpperTriangular(A.T) == UpperTriangular(B.T)
135+
end
136+
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
137+
return isequal(A.factors, B.factors) && isequal(UpperTriangular(A.T), UpperTriangular(B.T))
138+
end
139+
130140
"""
131141
QRPivoted <: Factorization
132142
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
@testset "equality for factorizations - $f" for f in Any[
2+
bunchkaufman,
3+
cholesky,
4+
x -> cholesky(x, Val(true)),
5+
eigen,
6+
hessenberg,
7+
lq,
8+
lu,
9+
qr,
10+
x -> qr(x, ColumnNorm()),
11+
svd,
12+
schur,
13+
]
14+
A = randn(3, 3)
15+
A = A * A' # ensure A is pos. def. and symmetric
16+
F, G = f(A), f(A)
17+
18+
@test F == G
19+
@test isequal(F, G)
20+
@test hash(F) == hash(G)
21+
22+
f === hessenberg && continue
23+
24+
F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i
25+
x = getfield(F, i)
26+
return x isa AbstractArray{Float64} ? Float32.(x) : x
27+
end...)
28+
G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i
29+
x = getfield(G, i)
30+
return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x
31+
end...)
32+
33+
@test F == G
34+
@test isequal(F, G)
35+
@test hash(F) == hash(G)
36+
end
37+
38+
@testset "hash collisions" begin
39+
A, v = randn(2, 2), randn(2)
40+
F, G = LQ(A, v), QR(A, v)
41+
@test !isequal(F, G)
42+
@test hash(F) != hash(G)
43+
end

stdlib/LinearAlgebra/test/testgroups

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ givens
2525
structuredbroadcast
2626
addmul
2727
ldlt
28+
factorization

0 commit comments

Comments
 (0)