Skip to content

test_rrule fails to handle QRCompactWY return types #184

Open
@rkube

Description

@rkube

`qr` returns a `LinearAlgebra.QRCompactWY`, which represents two matrices (Q and R) rather than a single array. When testing a custom pullback for `qr`, `test_rrule` calls `length` on this type, and no `length` method is defined for it.

Minimal working example (MWE):

using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences

using ChainRulesTestUtils

# Turn on ChainRulesCore's debug mode; redefining `debug_mode` to return `true` is how
# it is switched on.
function ChainRulesCore.debug_mode()
    return true
end
# Seed the global RNG so the random test matrix used further down is reproducible.
Random.seed!(1234)

# Reverse rule for the unpivoted QR factorization `F = qr(A)`.
#
# The pullback expects a `Tangent` for the returned `QRCompactWY` in which the `factors`
# field carries Q̄ (cotangent of `F.Q`) and the `T` field carries R̄ (cotangent of `F.R`);
# either may be an `AbstractZero` when that factor was not used downstream.
# Returns `(NoTangent(), Ā)` with `Ā` of size `size(A)`, or `(NoTangent(), ZeroTangent())`
# when both Q̄ and R̄ are zero.
# NOTE(review): the formulas follow the real-valued derivations cited below; complex
# inputs have not been checked.
function ChainRulesCore.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    # Square/tall rule of Seeger et al. (2019) https://arxiv.org/pdf/1710.08717.pdf:
    #
    #     Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ,    M = R R̄ᵀ - Q̄ᵀ Q
    #
    # where copyltu(C) is the symmetric matrix obtained by copying the lower triangle of
    # C onto its upper triangle: copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}.
    # The local `M` below is the transpose of the paper's M, so mirroring its *upper*
    # triangle onto the lower one yields copyltu of the paper's M.
    # This helper is reused for the leading square block of the wide case.
    function qr_pullback_square_deep(Q̄, R̄, Q, R)
        M = R̄ * R' - Q' * Q̄
        M = triu(M) + transpose(triu(M, 1))
        return (Q̄ + Q * M) / R'
    end

    function qr_pullback(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        if Q̄ isa ChainRulesCore.AbstractZero && R̄ isa ChainRulesCore.AbstractZero
            # Nothing flowed back through Q or R; without this early return the helper
            # above would call `triu` on a zero tangent.
            return (NoTangent(), ZeroTangent())
        end

        # `QR.Q` behaves like a square m×m operator, but the formulas need the thin
        # m×min(m, n) factor with orthonormal columns, which is what `Matrix` returns.
        Q = Matrix(QR.Q)
        R = QR.R

        if m ≥ n
            # Keep only the columns of Q̄ that pair with the thin Q (this accepts both a
            # full m×m and a thin m×n Q̄).
            Q̄ = Q̄ isa ChainRulesCore.AbstractZero ? Q̄ : @view Q̄[:, axes(Q, 2)]
            Ā = qr_pullback_square_deep(Q̄, R̄, Q, R)
        else
            # Wide case (m < n), Liao et al. (2019) https://arxiv.org/pdf/1903.09650.pdf:
            # partition A = [X | Y] with X m×m and Y m×(n-m), and likewise R = [U | V].
            # Then
            #     X̄ = square rule applied to (Q̄ + Y V̄ᵀ, Ū) with factors (Q, U)
            #     Ȳ = Q V̄
            # See also https://github.com/JuliaDiff/ChainRules.jl/pull/306 and
            # https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
            Y = A[1:m, (m + 1):end]
            U = R[1:m, 1:m]   # V itself is not needed
            if R̄ isa ChainRulesCore.AbstractZero
                # Allocate in Q's element type so e.g. Float32 inputs stay Float32.
                V̄ = zeros(eltype(Q), size(Y))
                Q̄_prime = zeros(eltype(Q), size(Q))
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
            end
            if !(Q̄ isa ChainRulesCore.AbstractZero)
                Q̄_prime = Q̄_prime + Q̄
            end

            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, Q, U)
            # Named ĀY (not Ȳ) so the pullback argument is not shadowed.
            ĀY = Q * V̄
            # Ā = [X̄ | ĀY]
            Ā = [X̄ ĀY]
        end
        return (NoTangent(), Ā)
    end
    return QR, qr_pullback
end


# Reverse rule for `getproperty(F::QRCompactWY, d)`, designed to pair with the `qr`
# rrule, which reads Q̄ from the tangent's `factors` field and R̄ from its `T` field.
# Accessing any other property yields a zero tangent for `F`.
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # The factorization is stored as `factors` and `T` in the compact WY format, see
        # R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
        # Instead of backpropagating through those fields, reuse `factors` to carry Q̄
        # and `T` to carry R̄.
        # Unused slots are `ZeroTangent()` rather than `nothing`, so that the `qr`
        # pullback's `isa AbstractZero` checks recognize them.
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()

        ∂F = Tangent{typeof(F)}(; factors=∂factors, T=∂T)
        # One tangent per input of `getproperty(F, d)`: the function, `F`, and `d`.
        return (NoTangent(), ∂F, NoTangent())
    end

    return getproperty(F, d), getproperty_qr_pullback
end

# Check the `qr` rrule with ChainRulesTestUtils on a random 4×4 input.
# This is the call that fails: `test_rrule` calls `length` on the returned
# `QRCompactWY`, and no such method exists.
V = randn(4, 4)
test_rrule(qr, V)

Fails with:

  Got exception outside of a @test
  MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
  Closest candidates are:
    length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
    length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
    length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195

This is because `length` has no method for `LinearAlgebra.QRCompactWY`:

julia> typeof(qr(V))
LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}}
julia> length(qr(V))
ERROR: MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
  length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[9]:1

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions