Skip to content

Commit c646b5d

Browse files
authored
Specialize findmax/findmin on SparseVector, fixes #42823 (#42825)
1 parent 2179795 commit c646b5d

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

stdlib/SparseArrays/src/sparsematrix.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,8 +2172,10 @@ end
21722172
_isless_fm(a, b) = b == b && ( a != a || isless(a, b) )
21732173
_isgreater_fm(a, b) = b == b && ( a != a || isless(b, a) )
21742174

2175-
findmin(A::AbstractSparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isless_fm, A, region, Tv)
2176-
findmax(A::AbstractSparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isgreater_fm, A, region, Tv)
2175+
findmin(A::AbstractSparseMatrixCSC{Tv}, region::Union{Integer,Tuple{Integer},NTuple{2,Integer}}) where {Tv} =
2176+
_findr(_isless_fm, A, region, Tv)
2177+
findmax(A::AbstractSparseMatrixCSC{Tv}, region::Union{Integer,Tuple{Integer},NTuple{2,Integer}}) where {Tv} =
2178+
_findr(_isgreater_fm, A, region, Tv)
21772179
findmin(A::AbstractSparseMatrixCSC) = (r=findmin(A,(1,2)); (r[1][1], r[2][1]))
21782180
findmax(A::AbstractSparseMatrixCSC) = (r=findmax(A,(1,2)); (r[1][1], r[2][1]))
21792181

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,27 @@ end
14111411

14121412
minimum(x::AbstractSparseVector) = minimum(identity, x)
14131413

1414+
for (fun, comp, word) in ((:findmin, :(<), "minimum"), (:findmax, :(>), "maximum"))
1415+
@eval function $fun(f, x::AbstractSparseVector{T}) where {T}
1416+
n = length(x)
1417+
n > 0 || throw(ArgumentError($word * " over empty array is not allowed"))
1418+
nzvals = nonzeros(x)
1419+
m = length(nzvals)
1420+
m == 0 && return zero(T), firstindex(x)
1421+
val, index = $fun(f, nzvals)
1422+
m == n && return val, index
1423+
nzinds = nonzeroinds(x)
1424+
zeroval = f(zero(T))
1425+
$comp(val, zeroval) && return val, nzinds[index]
1426+
# we need to find the first zero, which could be stored or implicit
1427+
# we try to avoid findfirst(iszero, x)
1428+
sindex = findfirst(iszero, nzvals) # first stored zero, if any
1429+
zindex = findfirst(i -> i < nzinds[i], eachindex(nzinds)) # first non-stored zero
1430+
index = isnothing(sindex) ? zindex : min(sindex, zindex)
1431+
return zeroval, index
1432+
end
1433+
end
1434+
14141435
norm(x::SparseVectorUnion, p::Real=2) = norm(nonzeros(x), p)
14151436

14161437
### linalg.jl

stdlib/SparseArrays/test/sparsevector.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,10 +816,14 @@ end
816816
@test norm(x, Inf) == 3.5
817817
end
818818

819-
@testset "maximum, minimum" begin
819+
@testset "maximum, minimum, findmax, findmin" begin
820820
let x = spv_x1
821821
@test maximum(x) == 3.5
822+
@test findmax(x) == findmax(Vector(x)) == (3.5, 6)
823+
@test findmax(x -> -x, x) == findmax(-x) == (0.75, 5)
822824
@test minimum(x) == -0.75
825+
@test findmin(x) == findmin(Vector(x)) == (-0.75, 5)
826+
@test findmin(x -> -x, x) == findmin(-x) == (-3.5, 6)
823827
@test maximum(abs, x) == 3.5
824828
@test minimum(abs, x) == 0.0
825829
@test @inferred(minimum(t -> true, x)) === true
@@ -832,21 +836,51 @@ end
832836

833837
let x = abs.(spv_x1)
834838
@test maximum(x) == 3.5
839+
@test findmax(x) == findmax(Vector(x)) == (3.5, 6)
840+
@test findmax(abs2, x) == findmax(abs2.(x)) == findmax(Vector(abs2.(x)))
835841
@test minimum(x) == 0.0
842+
@test findmin(x) == findmin(Vector(x)) == (0.0, 1)
843+
@test findmin(abs2, x) == findmin(abs2.(x)) == findmin(Vector(abs2.(x)))
836844
end
837845

838846
let x = -abs.(spv_x1)
839847
@test maximum(x) == 0.0
848+
@test findmax(x) == findmax(Vector(x)) == (0.0, 1)
840849
@test minimum(x) == -3.5
850+
@test findmin(x) == findmin(Vector(x)) == (-3.5, 6)
841851
end
842852

843853
let x = SparseVector(3, [1, 2, 3], [-4.5, 2.5, 3.5])
844854
@test maximum(x) == 3.5
855+
@test findmax(x) == findmax(Vector(x)) == (3.5, 3)
845856
@test minimum(x) == -4.5
857+
@test findmin(x) == findmin(Vector(x)) == (-4.5, 1)
846858
@test maximum(abs, x) == 4.5
847859
@test minimum(abs, x) == 2.5
848860
end
849861

862+
let x = SparseVector(3, [1, 2, 3], [4.5, 0.0, 3.5])
863+
@test minimum(x) == 0.0
864+
@test findmin(x) == findmin(Vector(x)) == (0.0, 2)
865+
end
866+
867+
let x = SparseVector(3, [1, 2, 3], [-4.5, 0.0, -3.5])
868+
@test maximum(x) == 0.0
869+
@test findmax(x) == findmax(Vector(x)) == (0.0, 2)
870+
end
871+
872+
for i in (2, 3)
873+
let x = SparseVector(4, [1, i, 4], [4.5, 0.0, 3.5])
874+
@test minimum(x) == 0.0
875+
@test findmin(x) == findmin(Vector(x)) == (0.0, 2)
876+
end
877+
878+
let x = SparseVector(4, [1, i, 4], [-4.5, 0.0, -3.5])
879+
@test maximum(x) == 0.0
880+
@test findmax(x) == findmax(Vector(x)) == (0.0, 2)
881+
end
882+
end
883+
850884
let x = spzeros(Float64, 8)
851885
@test maximum(x) == 0.0
852886
@test minimum(x) == 0.0
@@ -861,6 +895,8 @@ end
861895
let x = spzeros(Float64, 0)
862896
@test_throws ArgumentError minimum(t -> true, x)
863897
@test_throws ArgumentError maximum(t -> true, x)
898+
@test_throws ArgumentError findmin(x)
899+
@test_throws ArgumentError findmax(x)
864900
end
865901
end
866902

0 commit comments

Comments
 (0)