Skip to content

Commit e925450

Browse files
authored
Merge pull request #19 from N5N3/N5N3-patch-1
N5 n3 patch 1
2 parents b3b6355 + a59ee72 commit e925450

File tree

3 files changed

+21
-20
lines changed

3 files changed

+21
-20
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
317317
return C
318318
end
319319

320-
(/)(A::AbstractVecOrMat, D::Diagonal) =
321-
_rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D)
320+
_promote_dotop(f, args...) = promote_op(f, eltype.(args)...)
321+
322+
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _promote_dotop(/, A, D), size(A)), A, D)
322323

323324
rdiv!(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(A, A, D)
324325
# avoid copy when possible via internal 3-arg backend
@@ -338,21 +339,10 @@ function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
338339
end
339340
B
340341
end
341-
# Optimization for Diagonal / Diagonal
342-
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
343-
n, k = length(Db.diag), length(Db.diag)
344-
n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
345-
j = findfirst(iszero, Da.diag)
346-
isnothing(j) || throw(SingularException(j))
347-
Dc.diag .= Db.diag ./ Da.diag
348-
Dc
349-
end
350342

351-
(\)(D::Diagonal, B::AbstractVecOrMat) =
352-
ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B)
343+
\(D::Diagonal, B::AbstractVecOrMat) = ldiv!(similar(B, _promote_dotop(\, D, B), size(B)), D, B)
353344

354345
ldiv!(D::Diagonal, B::AbstractVecOrMat) = ldiv!(B, D, B)
355-
ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag))
356346
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
357347
require_one_based_indexing(A, B)
358348
d = length(D.diag)
@@ -365,6 +355,19 @@ function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
365355
B .= D.diag .\ A
366356
end
367357

358+
# Optimizations for \, / between Diagonals
359+
\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, _promote_dotop(\, D, B)), D, B)
360+
/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, _promote_dotop(/, A, D)), A, D)
361+
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
362+
n, k = length(Db.diag), length(Db.diag)
363+
n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
364+
j = findfirst(iszero, Da.diag)
365+
isnothing(j) || throw(SingularException(j))
366+
Dc.diag .= Db.diag ./ Da.diag
367+
Dc
368+
end
369+
ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag))
370+
368371
# (l/r)mul!, l/rdiv!, *, / and \ Optimization for AbstractTriangular.
369372
# These functions are generally more efficient if we calculate the whole data field.
370373
# The following code implements them in a unified patten to avoid missing.

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ Random.seed!(1)
204204
@test D2*D' Array(D2)*Array(D)'
205205

206206
#division of two Diagonals
207-
@test D/D2 Diagonal(D.diag./D2.diag)
208-
@test D\D2 Diagonal(D2.diag./D.diag)
207+
@test (D/D2)::Diagonal Diagonal(D.diag./D2.diag)
208+
@test (D\D2)::Diagonal Diagonal(D2.diag./D.diag)
209209

210210
# QR \ Diagonal
211211
A = rand(elty, n, n)

stdlib/LinearAlgebra/test/hessenberg.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,10 @@ let n = 10
9090
@testset "Multiplication/division" begin
9191
for x = (5, 5I, Diagonal(d), Bidiagonal(d,dl,:U),
9292
UpperTriangular(A), UnitUpperTriangular(A))
93-
@test H*x == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal
94-
@test x*H == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal
93+
@test (H*x)::UpperHessenberg == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal
94+
@test (x*H)::UpperHessenberg == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal
9595
@test H/x == Array(H)/x broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular}
9696
@test x\H == x\Array(H) broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular}
97-
@test H*x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
98-
@test x*H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
9997
@test H/x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
10098
@test x\H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
10199
end

0 commit comments

Comments
 (0)