Skip to content
This repository was archived by the owner on May 23, 2022. It is now read-only.

Commit 41d4f18

Browse files
Merge pull request #51 from JuliaML/cl/obs
add getobs and nobs common implementations
2 parents 534f8d0 + e7e8f15 commit 41d4f18

File tree

3 files changed

+126
-35
lines changed

3 files changed

+126
-35
lines changed

src/LearnBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module LearnBase
22

3+
import StatsBase
34
using StatsBase: nobs
45

56
# AGGREGATION MODES

src/observation.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ Specify the default observation dimension for `data`.
55
Falls back to `nothing` when an observation dimension is undefined.
66
77
By default, the following implementations are provided:
8-
- `default_obsdim(A::AbstractArray) = ndims(A)`
9-
- `default_obsdim(tup::Tuple) = map(default_obsdim, tup)`
8+
```julia
9+
default_obsdim(x) = nothing
10+
default_obsdim(x::AbstractArray) = ndims(x)
11+
```
1012
"""
1113
default_obsdim(data) = nothing
12-
default_obsdim(A::AbstractArray) = ndims(A)
13-
default_obsdim(tup::Tuple) = map(default_obsdim, tup)
14+
default_obsdim(A::AbstractArray{T,N}) where {T,N} = N
1415

1516
"""
1617
getobs(data, idx; obsdim = default_obsdim(data))
@@ -74,6 +75,24 @@ getobs(dataset, 1:2) # -> (X[:,1:2], Y[1:2])
7475
"""
7576
function getobs end
7677

78+
function getobs(data::AbstractArray{T,N}, idx; obsdim::Union{Int,Nothing}=nothing) where {T, N}
79+
od = obsdim === nothing ? default_obsdim(data) : obsdim
80+
_idx = ntuple(i -> i == od ? idx : Colon(), N)
81+
data[_idx...]
82+
end
83+
84+
function getobs(data::Union{Tuple, NamedTuple}, i; obsdim::Union{Int,Nothing}=default_obsdim(data))
85+
# We don't force users to handle the obsdim keyword if not necessary.
86+
fobs = obsdim === nothing ? Base.Fix2(getobs, i) : x -> getobs(x, i; obsdim=obsdim)
87+
map(fobs, data)
88+
end
89+
90+
function getobs(data::D, i; obsdim::Union{Int,Nothing}=default_obsdim(data)) where {D<:AbstractDict}
91+
fobs = obsdim === nothing ? Base.Fix2(getobs, i) : x -> getobs(x, i; obsdim=obsdim)
92+
# Cannot return D because the value type can change
93+
Dict(k => fobs(v) for (k, v) in pairs(data))
94+
end
95+
7796
"""
7897
getobs!(buffer, data, idx; obsdim = default_obsdim(data))
7998
@@ -123,6 +142,37 @@ to dispatch on which dimension of `data` denotes the observations.
123142
"""
124143
function datasubset end
125144

145+
146+
# We don't own nobs but pirate it for basic types
147+
"""
148+
nobs(data; [obsdim])
149+
150+
Return the number of observations in the dataset `data`.
151+
152+
If it makes sense for the type of `data`, `obsdim` can be used
153+
to indicate which dimension of `data` denotes the observations.
154+
See [`default_obsdim`](@ref) for defining a default dimension.
155+
"""
156+
function StatsBase.nobs(data::AbstractArray; obsdim::Union{Int,Nothing}=nothing)
157+
od = obsdim === nothing ? default_obsdim(data) : obsdim
158+
size(data, od)
159+
end
160+
161+
function StatsBase.nobs(data::Union{Tuple, NamedTuple, AbstractDict}; obsdim::Union{Int,Nothing} = default_obsdim(data))
162+
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
163+
164+
# We don't force users to handle the obsdim
165+
# keyword if not necessary.
166+
fnobs = obsdim === nothing ? nobs : x -> nobs(x; obsdim=obsdim)
167+
168+
n = fnobs(data[first(keys(data))])
169+
for i in keys(data)
170+
ni = fnobs(data[i])
171+
n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. "))
172+
end
173+
return n
174+
end
175+
126176
# todeprecate
127177
function target end
128178
function gettarget end

test/observation.jl

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,78 @@ using LearnBase: getobs, nobs, default_obsdim
55
@test typeof(LearnBase.gettargets) <: Function
66
@test typeof(LearnBase.datasubset) <: Function
77

8-
@testset "getobs" begin
9-
10-
function LearnBase.getobs(x::AbstractArray{T,N}, idx; obsdim=default_obsdim(x)) where {T,N}
11-
_idx = ntuple(i-> i == obsdim ? idx : Colon(), N)
12-
return x[_idx...]
8+
@testset "getobs and nobs" begin
9+
10+
@testset "array" begin
11+
a = rand(2,3)
12+
@test nobs(a) == 3
13+
@test @inferred getobs(a, 1) == a[:,1]
14+
@test @inferred getobs(a, 2) == a[:,2]
15+
@test @inferred getobs(a, 1:2) == a[:,1:2]
16+
@test @inferred getobs(a, 1, obsdim=1) == a[1,:]
17+
@test @inferred getobs(a, 2, obsdim=1) == a[2,:]
18+
@test @inferred getobs(a, 2, obsdim=nothing) ≈ a[:,2]
19+
end
20+
21+
@testset "tuple" begin
22+
# A dataset with 3 observations, each with 2 input features
23+
X, Y = rand(2, 3), rand(3)
24+
dataset = (X, Y)
25+
@test nobs(dataset) == 3
26+
if VERSION >= v"1.6"
27+
o = @inferred getobs(dataset, 2)
28+
else
29+
o = getobs(dataset, 2)
30+
end
31+
@test o[1] == X[:,2]
32+
@test o[2] == Y[2]
33+
34+
if VERSION >= v"1.6"
35+
o = @inferred getobs(dataset, 1:2)
36+
else
37+
o = getobs(dataset, 1:2)
38+
end
39+
40+
@test o[1] == X[:,1:2]
41+
@test o[2] == Y[1:2]
42+
end
43+
44+
45+
@testset "named tuple" begin
46+
X, Y = rand(2, 3), rand(3)
47+
dataset = (x=X, y=Y)
48+
@test nobs(dataset) == 3
49+
if VERSION >= v"1.6"
50+
o = @inferred getobs(dataset, 2)
51+
else
52+
o = getobs(dataset, 2)
53+
end
54+
@test o.x == X[:,2]
55+
@test o.y == Y[2]
56+
57+
if VERSION >= v"1.6"
58+
o = @inferred getobs(dataset, 1:2)
59+
else
60+
o = getobs(dataset, 1:2)
61+
end
62+
@test o.x == X[:,1:2]
63+
@test o.y == Y[1:2]
64+
end
65+
66+
@testset "dict" begin
67+
X, Y = rand(2, 3), rand(3)
68+
dataset = Dict("X" => X, "Y" => Y)
69+
@test nobs(dataset) == 3
70+
71+
# o = @inferred getobs(dataset, 2) # not inferred
72+
o = getobs(dataset, 2)
73+
@test o["X"] == X[:,2]
74+
@test o["Y"] == Y[2]
75+
76+
o = getobs(dataset, 1:2)
77+
@test o["X"] == X[:,1:2]
78+
@test o["Y"] == Y[1:2]
1379
end
14-
LearnBase.nobs(x::AbstractArray; obsdim=default_obsdim(x)) = size(x, obsdim)
15-
16-
a = rand(2,3)
17-
@test nobs(a) == 3
18-
@test getobs(a, 1) ≈ a[:,1]
19-
@test getobs(a, 2) ≈ a[:,2]
20-
@test getobs(a, 1, obsdim=1) ≈ a[1,:]
21-
@test getobs(a, 2, obsdim=1) ≈ a[2,:]
22-
23-
# Here we use Ref to protect idx against broadcasting
24-
LearnBase.getobs(t::Tuple, idx) = getobs.(t, Ref(idx))
25-
# Assume all elements have the same number of observations.
26-
# It would be safer to check explicitly though.
27-
LearnBase.nobs(t::Tuple) = nobs(t[1])
28-
29-
# A dataset with 3 observations, each with 2 input features
30-
X, Y = rand(2, 3), rand(3)
31-
dataset = (X, Y)
32-
33-
o = getobs(dataset, 2) # -> (X[:,2], Y[2])
34-
@test o[1] ≈ X[:,2]
35-
@test o[2] == Y[2]
36-
37-
o = getobs(dataset, 1:2) # -> (X[:,1:2], Y[1:2])
38-
@test o[1] ≈ X[:,1:2]
39-
@test o[2] == Y[1:2]
4080
end
4181

4282

0 commit comments

Comments
 (0)