Skip to content

[WIP] faster reductions #659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 26, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,23 @@ end
end
end

@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
_mapreduce(f, op, Val(D), nt, sz, a)
@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
# Body of this function is split because constant propagation (at least
# as of Julia 1.2) can't always correctly propagate here and
# as a result the function is not type stable and very slow.
# This makes it at least fast for three dimensions but people should use
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
if D == 1
return _mapreduce(f, op, Val(1), nt, sz, a)
elseif D == 2
return _mapreduce(f, op, Val(2), nt, sz, a)
elseif D == 3
return _mapreduce(f, op, Val(3), nt, sz, a)
else
return _mapreduce(f, op, Val(D), nt, sz, a)
end
end


@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Expand Down Expand Up @@ -161,7 +174,9 @@ end
## reduce ##
############

@inline reduce(op, a::StaticArray; kw...) = mapreduce(identity, op, a; kw...)
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)

@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)

#######################
## related functions ##
Expand All @@ -186,38 +201,38 @@ end
# TODO: change to use Base.reduce_empty/Base.reduce_first
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)

@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(+, a; dims=dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity

@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)

@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(+, a; dims=dims)
@inline count(f, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, +, a; dims=dims)
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)

@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, &, a; dims=dims, init=true)
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)

@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, |, a; dims=dims, init=false) # (benchmarking needed)
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)

@inline Base.in(x, a::StaticArray) = mapreduce(==(x), |, a, init=false)
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)

_mean_denom(a, dims::Colon) = length(a)
_mean_denom(a, dims::Int) = size(a, dims)
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)

@inline mean(a::StaticArray; dims=:) = sum(a; dims=dims) / _mean_denom(a,dims)
@inline mean(f::Function, a::StaticArray;dims=:) = sum(f, a; dims=dims) / _mean_denom(a,dims)
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)

@inline minimum(a::StaticArray; dims=:) = reduce(min, a; dims=dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = mapreduce(f, min, a; dims=dims)
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)

@inline maximum(a::StaticArray; dims=:) = reduce(max, a; dims=dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = mapreduce(f, max, a; dims=dims)
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)

# Diff is slightly different
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
Expand Down