@@ -28,6 +28,35 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 class TreeSplitUtilsSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  /**
+   * Iterate over feature values and labels for a specific (node, feature), updating stats
+   * aggregator for the current node.
+   */
+  private[impl] def updateAggregator(
+      statsAggregator: DTStatsAggregator,
+      featureIndex: Int,
+      values: Array[Int],
+      indices: Array[Int],
+      instanceWeights: Array[Double],
+      labels: Array[Double],
+      from: Int,
+      to: Int,
+      featureIndexIdx: Int,
+      featureSplits: Array[Split]): Unit = {
+    val metadata = statsAggregator.metadata
+    from.until(to).foreach { idx =>
+      val rowIndex = indices(idx)
+      if (metadata.isUnordered(featureIndex)) {
+        AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
+          featureIndex = featureIndex, featureIndexIdx, featureSplits,
+          instanceWeight = instanceWeights(rowIndex))
+      } else {
+        AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
+          featureIndexIdx, instanceWeight = instanceWeights(rowIndex))
+      }
+    }
+  }
+
   /**
    * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
    * with the data from the specified training points.
@@ -40,12 +69,13 @@ class TreeSplitUtilsSuite
       labels: Array[Double],
       featureSplits: Array[Split]): DTStatsAggregator = {
 
+    val featureIndex = 0
     val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
     val instanceWeights = Array.fill[Double](values.length)(1.0)
     val indices = values.indices.toArray
     AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
-    LocalDecisionTree.updateAggregator(statsAggregator, col, indices, instanceWeights, labels,
-      from, to, col.featureIndex, featureSplits)
+    updateAggregator(statsAggregator, featureIndex, values, indices, instanceWeights, labels,
+      from, to, featureIndex, featureSplits)
     statsAggregator
   }
 
@@ -73,34 +103,34 @@ class TreeSplitUtilsSuite
   test("chooseSplit: choose correct type of split (continuous split)") {
     // Construct (binned) continuous data
     val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
-    val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6))
+    val values = Array(8, 1, 1, 2, 3, 5, 6)
+    val featureIndex = 0
     // Get an array of continuous splits corresponding to values in our binned data
     val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0)
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = 7,
       numFeatures = 1, numClasses = 2, Map.empty)
-    val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits)
+    val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
     // Choose split, check that it's a valid ContinuousSplit
-    val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex,
-      col.featureIndex, splits)
+    val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex,
+      splits)
     assert(stats1.valid && split1.isInstanceOf[ContinuousSplit])
   }
 
   test("chooseSplit: choose correct type of split (categorical split)") {
     // Construct categorical data
     val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
-    val featureIndex = 0
     val featureArity = 3
     val values = Array(0, 0, 1, 1, 1, 2, 2)
-    val col = FeatureColumn(featureIndex, values)
+    val featureIndex = 0
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = 7,
       numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity))
     val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
-    val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits)
+    val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
     // Choose split, check that it's a valid categorical split
     val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator,
-      featureIndex = col.featureIndex, featureIndexIdx = col.featureIndex,
+      featureIndex = featureIndex, featureIndexIdx = featureIndex,
       featureSplits = splits)
     assert(stats2.valid && split2.isInstanceOf[CategoricalSplit])
   }
@@ -117,16 +147,14 @@ class TreeSplitUtilsSuite
     // Construct FeatureVector to store categorical data
     val featureArity = values.max + 1
     val arityMap = Map[Int, Int](featureIndex -> featureArity)
-    val col = FeatureColumn(featureIndex = 0, values = values)
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty))
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
       labels, featureSplits = Array.empty)
     // Choose split
     val (split, stats) =
-      SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex,
-        col.featureIndex)
+      SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
     // Verify that split has the expected left-side/right-side categories
     val expectedRightCategories = Range(0, featureArity)
       .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray
@@ -156,15 +184,14 @@ class TreeSplitUtilsSuite
     val values = Array(0, 0, 1, 2, 2, 2, 2)
     val featureArity = values.max + 1
     val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
-    val col = FeatureColumn(featureIndex, values)
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty))
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
       labels, featureSplits = Array.empty)
     // Choose split, verify that it's invalid
-    val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex,
-      col.featureIndex)
+    val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
+      featureIndex)
     assert(!stats.valid)
   }
 
@@ -177,17 +204,16 @@ class TreeSplitUtilsSuite
     val values = Array(1, 1, 0, 2, 2)
     val featureArity = values.max + 1
     val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0)
-    val col = FeatureColumn(featureIndex, values)
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 3, Map(featureIndex -> featureArity))
     val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
       labels, splits)
     // Choose split
     val (split, stats) =
-      SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, col.featureIndex,
-        col.featureIndex, splits)
+      SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
+        splits)
     // Verify that split has the expected left-side/right-side categories
     split match {
       case s: CategoricalSplit =>
@@ -208,12 +234,12 @@ class TreeSplitUtilsSuite
     val featureArity = 4
     val values = Array(3, 1, 0, 2, 2)
     val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0)
-    val col = FeatureColumn(featureIndex, values)
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map(featureIndex -> featureArity))
     val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
+      splits)
     // Choose split, verify that it's invalid
     val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
       featureIndex, splits)
@@ -226,13 +252,12 @@ class TreeSplitUtilsSuite
     val thresholds = Array(0, 1, 2, 3)
     val values = thresholds.indices.toArray
     val labels = Array(0.0, 0.0, 1.0, 1.0)
-    val col = FeatureColumn(featureIndex = featureIndex, values = values)
-
     // Construct DTStatsAggregator, compute sufficient stats
     val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map.empty)
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
+      splits)
 
     // Choose split, verify that it has expected threshold
     val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
@@ -256,12 +281,12 @@ class TreeSplitUtilsSuite
     val thresholds = Array(0, 1, 2, 3)
     val values = thresholds.indices.toArray
     val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0)
-    val col = FeatureColumn(featureIndex = featureIndex, values = values)
     // Construct DTStatsAggregator, compute sufficient stats
     val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map.empty[Int, Int])
-    val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
+    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
+      splits)
     // Choose split, verify that it's invalid
     val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
       featureIndex, splits)