|
249 | 249 | (*)(D::Diagonal, A::AbstractMatrix) =
|
250 | 250 | mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
|
251 | 251 |
|
252 |
| -rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D) |
253 |
| -lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B) |
| 252 | +rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) |
| 253 | +lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) |
254 | 254 |
|
255 | 255 | #TODO: It seems better to call (D' * adjA')' directly?
|
256 | 256 | function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal)
|
@@ -285,35 +285,80 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
|
285 | 285 | end
|
286 | 286 |
|
287 | 287 | @inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
|
288 |
| - if iszero(beta) |
289 |
| - out .= (D.diag .* B) .*ₛ alpha |
| 288 | + require_one_based_indexing(out) |
| 289 | + if iszero(alpha) |
| 290 | + _rmul_or_fill!(out, beta) |
290 | 291 | else
|
291 |
| - out .= (D.diag .* B) .*ₛ alpha .+ out .* beta |
| 292 | + if iszero(beta) |
| 293 | + @inbounds for j in axes(B, 2) |
| 294 | + @simd for i in axes(B, 1) |
| 295 | + out[i,j] = D.diag[i] * B[i,j] * alpha |
| 296 | + end |
| 297 | + end |
| 298 | + else |
| 299 | + @inbounds for j in axes(B, 2) |
| 300 | + @simd for i in axes(B, 1) |
| 301 | + out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta |
| 302 | + end |
| 303 | + end |
| 304 | + end |
292 | 305 | end
|
293 | 306 | return out
|
294 | 307 | end
|
295 |
| - |
296 | 308 | @inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
|
297 |
| - if iszero(beta) |
298 |
| - out .= (A .* permutedims(D.diag)) .*ₛ alpha |
| 309 | + require_one_based_indexing(out) |
| 310 | + if iszero(alpha) |
| 311 | + _rmul_or_fill!(out, beta) |
299 | 312 | else
|
300 |
| - out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta |
| 313 | + if iszero(beta) |
| 314 | + @inbounds for j in axes(A, 2) |
| 315 | + dja = D.diag[j] * alpha |
| 316 | + @simd for i in axes(A, 1) |
| 317 | + out[i,j] = A[i,j] * dja |
| 318 | + end |
| 319 | + end |
| 320 | + else |
| 321 | + @inbounds for j in axes(A, 2) |
| 322 | + dja = D.diag[j] * alpha |
| 323 | + @simd for i in axes(A, 1) |
| 324 | + out[i,j] = A[i,j] * dja + out[i,j] * beta |
| 325 | + end |
| 326 | + end |
| 327 | + end |
301 | 328 | end
|
302 | 329 | return out
|
303 | 330 | end
|
304 |
| - |
305 | 331 | @inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
|
306 |
| - if iszero(beta) |
307 |
| - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha |
| 332 | + d1 = D1.diag |
| 333 | + d2 = D2.diag |
| 334 | + if iszero(alpha) |
| 335 | + _rmul_or_fill!(out.diag, beta) |
308 | 336 | else
|
309 |
| - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta |
| 337 | + if iszero(beta) |
| 338 | + @inbounds @simd for i in eachindex(out.diag) |
| 339 | + out.diag[i] = d1[i] * d2[i] * alpha |
| 340 | + end |
| 341 | + else |
| 342 | + @inbounds @simd for i in eachindex(out.diag) |
| 343 | + out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta |
| 344 | + end |
| 345 | + end |
| 346 | + end |
| 347 | + return out |
| 348 | +end |
| 349 | +@inline function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) |
| 350 | + require_one_based_indexing(out) |
| 351 | + mA = size(D1, 1) |
| 352 | + d1 = D1.diag |
| 353 | + d2 = D2.diag |
| 354 | + _rmul_or_fill!(out, beta) |
| 355 | + if !iszero(alpha) |
| 356 | + @inbounds @simd for i in 1:mA |
| 357 | + out[i,i] += d1[i] * d2[i] * alpha |
| 358 | + end |
310 | 359 | end
|
311 | 360 | return out
|
312 | 361 | end
|
313 |
| - |
314 |
| -# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments |
315 |
| -@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) = |
316 |
| - mul!(out, D1, D2, alpha, beta) |
317 | 362 |
|
318 | 363 | @inline function _muldiag!(out, A, B, alpha, beta)
|
319 | 364 | _muldiag_size_check(out, A, B)
|
|
340 | 385 | @inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
|
341 | 386 | _muldiag!(C, Da, Db, alpha, beta)
|
342 | 387 |
|
343 |
| -function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) |
344 |
| - _muldiag_size_check(C, Da, Db) |
345 |
| - require_one_based_indexing(C) |
346 |
| - mA = size(Da, 1) |
347 |
| - da = Da.diag |
348 |
| - db = Db.diag |
349 |
| - _rmul_or_fill!(C, beta) |
350 |
| - if iszero(beta) |
351 |
| - @inbounds @simd for i in 1:mA |
352 |
| - C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha |
353 |
| - end |
354 |
| - else |
355 |
| - @inbounds @simd for i in 1:mA |
356 |
| - C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha |
357 |
| - end |
358 |
| - end |
359 |
| - return C |
360 |
| -end |
| 388 | +mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = |
| 389 | + _muldiag!(C, Da, Db, alpha, beta) |
361 | 390 |
|
362 | 391 | _init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
|
363 | 392 | (_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))
|
|
0 commit comments