Skip to content

Commit b519518

Browse files
committed
disambiguation
1 parent 01ae5d9 commit b519518

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
316316
return C
317317
end
318318

319-
#TODO: many of /, \ related function has no size check and singular check
319+
#TODO: many of /, \ related function has no singular check
320320
(/)(A::AbstractVecOrMat, D::Diagonal) =
321321
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
322322
(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)
@@ -345,6 +345,7 @@ end
345345
(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
346346
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)
347347

348+
#TODO: we should check size(x,2) == size(b,2)
348349
ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b)
349350

350351
function ldiv!(D::Diagonal, B::AbstractVecOrMat)
@@ -548,11 +549,17 @@ function svd(D::Diagonal{<:Number})
548549
return SVD(Up, S[piv], copy(Vp'))
549550
end
550551

551-
# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
552-
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
553-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
554-
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
555-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
552+
# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec
553+
*(x::AdjointAbsVec, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
554+
*(x::TransposeAbsVec, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
555+
*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
556+
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
557+
/(u::AdjointAbsVec, D::Diagonal) = adjoint(adjoint(D) \ u.parent)
558+
/(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) \ u.parent)
559+
# disambiguation methods: Call unoptimized version for user defined AbstractTriangular.
560+
*(A::AbstractTriangular, D::Diagonal) = Base.@invoke *(A::AbstractMatrix, D::Diagonal)
561+
*(D::Diagonal, A::AbstractTriangular) = Base.@invoke *(D::Diagonal, A::AbstractMatrix)
562+
556563
dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)
557564

558565
dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
@@ -620,4 +627,4 @@ end
620627

621628
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
622629
Diagonal(A.diag .* B.diag .+ z.diag)
623-
end
630+
end

0 commit comments

Comments
 (0)