Skip to content

Commit e6a5d2a

Browse files
dkarraschKristofferC
authored andcommitted
Fix kron with Diagonal (#40509)
(cherry picked from commit a20e547)
1 parent 179cc16 commit e6a5d2a

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -500,14 +500,13 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
500500
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)
501501

502502

503-
@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
504-
fill!(C, zero(T))
503+
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
505504
valA = A.diag; nA = length(valA)
506505
valB = B.diag; nB = length(valB)
507506
nC = checksquare(C)
508507
@boundscheck nC == nA*nB ||
509508
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
510-
509+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
511510
@inbounds for i = 1:nA, j = 1:nB
512511
idx = (i-1)*nB+j
513512
C[idx, idx] = valA[i] * valB[j]
@@ -525,9 +524,12 @@ end
525524

526525
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
527526
Base.require_one_based_indexing(B)
528-
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
527+
(mA, nA) = size(A)
528+
(mB, nB) = size(B)
529+
(mC, nC) = size(C)
529530
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
530531
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
532+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
531533
m = 1
532534
@inbounds for j = 1:nA
533535
A_jj = A[j,j]
@@ -545,9 +547,12 @@ end
545547

546548
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
547549
require_one_based_indexing(A)
548-
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
550+
(mA, nA) = size(A)
551+
(mB, nB) = size(B)
552+
(mC, nC) = size(C)
549553
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
550554
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
555+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
551556
m = 1
552557
@inbounds for j = 1:nA
553558
for l = 1:mB
@@ -563,18 +568,6 @@ end
563568
return C
564569
end
565570

566-
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
567-
(mA, nA) = size(A); (mB, nB) = size(B)
568-
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
569-
return @inbounds kron!(R, A, B)
570-
end
571-
572-
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
573-
(mA, nA) = size(A); (mB, nB) = size(B)
574-
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
575-
return @inbounds kron!(R, A, B)
576-
end
577-
578571
conj(D::Diagonal) = Diagonal(conj(D.diag))
579572
transpose(D::Diagonal{<:Number}) = D
580573
transpose(D::Diagonal) = Diagonal(transpose.(D.diag))

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ Random.seed!(1)
295295
M4 = rand(elty, n÷2, n÷2)
296296
@test kron(D3, M4) kron(DM3, M4)
297297
@test kron(M4, D3) kron(M4, DM3)
298+
X = [ones(1,1) for i in 1:2, j in 1:2]
299+
@test kron(I(2), X)[1,3] == zeros(1,1)
300+
X = [ones(2,2) for i in 1:2, j in 1:2]
301+
@test kron(I(2), X)[1,3] == zeros(2,2)
298302
end
299303
@testset "iszero, isone, triu, tril" begin
300304
Dzero = Diagonal(zeros(elty, 10))

0 commit comments

Comments
 (0)