Skip to content

Commit 03af781

Browse files
authored
Fix performance issue with diagonal multiplication (#44651)
1 parent 1a7355b commit 03af781

File tree

2 files changed

+65
-37
lines changed

2 files changed

+65
-37
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 64 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -249,8 +249,8 @@ end
249249
(*)(D::Diagonal, A::AbstractMatrix) =
250250
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
251251

252-
rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
253-
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)
252+
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
253+
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
254254

255255
#TODO: It seems better to call (D' * adjA')' directly?
256256
function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal)
@@ -285,35 +285,80 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
285285
end
286286

287287
@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
288-
if iszero(beta)
289-
out .= (D.diag .* B) .*ₛ alpha
288+
require_one_based_indexing(out)
289+
if iszero(alpha)
290+
_rmul_or_fill!(out, beta)
290291
else
291-
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
292+
if iszero(beta)
293+
@inbounds for j in axes(B, 2)
294+
@simd for i in axes(B, 1)
295+
out[i,j] = D.diag[i] * B[i,j] * alpha
296+
end
297+
end
298+
else
299+
@inbounds for j in axes(B, 2)
300+
@simd for i in axes(B, 1)
301+
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
302+
end
303+
end
304+
end
292305
end
293306
return out
294307
end
295-
296308
@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
297-
if iszero(beta)
298-
out .= (A .* permutedims(D.diag)) .*ₛ alpha
309+
require_one_based_indexing(out)
310+
if iszero(alpha)
311+
_rmul_or_fill!(out, beta)
299312
else
300-
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
313+
if iszero(beta)
314+
@inbounds for j in axes(A, 2)
315+
dja = D.diag[j] * alpha
316+
@simd for i in axes(A, 1)
317+
out[i,j] = A[i,j] * dja
318+
end
319+
end
320+
else
321+
@inbounds for j in axes(A, 2)
322+
dja = D.diag[j] * alpha
323+
@simd for i in axes(A, 1)
324+
out[i,j] = A[i,j] * dja + out[i,j] * beta
325+
end
326+
end
327+
end
301328
end
302329
return out
303330
end
304-
305331
@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
306-
if iszero(beta)
307-
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
332+
d1 = D1.diag
333+
d2 = D2.diag
334+
if iszero(alpha)
335+
_rmul_or_fill!(out.diag, beta)
308336
else
309-
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
337+
if iszero(beta)
338+
@inbounds @simd for i in eachindex(out.diag)
339+
out.diag[i] = d1[i] * d2[i] * alpha
340+
end
341+
else
342+
@inbounds @simd for i in eachindex(out.diag)
343+
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
344+
end
345+
end
346+
end
347+
return out
348+
end
349+
@inline function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta)
350+
require_one_based_indexing(out)
351+
mA = size(D1, 1)
352+
d1 = D1.diag
353+
d2 = D2.diag
354+
_rmul_or_fill!(out, beta)
355+
if !iszero(alpha)
356+
@inbounds @simd for i in 1:mA
357+
out[i,i] += d1[i] * d2[i] * alpha
358+
end
310359
end
311360
return out
312361
end
313-
314-
# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
315-
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
316-
mul!(out, D1, D2, alpha, beta)
317362

318363
@inline function _muldiag!(out, A, B, alpha, beta)
319364
_muldiag_size_check(out, A, B)
@@ -340,24 +385,8 @@ end
340385
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
341386
_muldiag!(C, Da, Db, alpha, beta)
342387

343-
function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
344-
_muldiag_size_check(C, Da, Db)
345-
require_one_based_indexing(C)
346-
mA = size(Da, 1)
347-
da = Da.diag
348-
db = Db.diag
349-
_rmul_or_fill!(C, beta)
350-
if iszero(beta)
351-
@inbounds @simd for i in 1:mA
352-
C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha
353-
end
354-
else
355-
@inbounds @simd for i in 1:mA
356-
C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha
357-
end
358-
end
359-
return C
360-
end
388+
mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
389+
_muldiag!(C, Da, Db, alpha, beta)
361390

362391
_init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
363392
(_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,7 @@
88
# inside this function.
99
function *end
1010
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
11-
iszero(beta::Number) ? false :
12-
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
11+
iszero(beta::Number) ? false : broadcasted(*, out, beta)
1312

1413
"""
1514
MulAddMul(alpha, beta)

0 commit comments

Comments
 (0)