@@ -30,52 +30,51 @@ class TreeSplitUtilsSuite
30
30
31
31
/**
32
32
* Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
33
- * with the data from the specified training points.
33
+ * with the data from the specified training points. Assumes a feature index of 0 and that
34
+ * all training points have the same weights (1.0).
34
35
*/
35
36
private def getAggregator (
36
37
metadata : DecisionTreeMetadata ,
37
38
values : Array [Int ],
38
- from : Int ,
39
- to : Int ,
40
39
labels : Array [Double ],
41
40
featureSplits : Array [Split ]): DTStatsAggregator = {
42
-
43
- val featureIndex = 0
41
+ // Create stats aggregator
44
42
val statsAggregator = new DTStatsAggregator (metadata, featureSubset = None )
45
- val indices = values.indices.toArray
46
- val instanceWeights = Array .fill[Double ](values.length)(1.0 )
47
43
// Update parent impurity stats
48
- AggUpdateUtils .updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
44
+ val featureIndex = 0
45
+ val instanceWeights = Array .fill[Double ](values.length)(1.0 )
46
+ AggUpdateUtils .updateParentImpurity(statsAggregator, indices = values.indices.toArray,
47
+ from = 0 , to = values.length, instanceWeights, labels)
49
48
// Update current aggregator's impurity stats
50
- from.until(to).foreach { idx =>
51
- val rowIndex = indices(idx)
49
+ values.zip(labels).foreach { case (value : Int , label : Double ) =>
52
50
if (metadata.isUnordered(featureIndex)) {
53
- AggUpdateUtils .updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
54
- featureIndex = featureIndex, featureIndexIdx, featureSplits,
55
- instanceWeight = 1.0 )
51
+ AggUpdateUtils .updateUnorderedFeature(statsAggregator, value, label,
52
+ featureIndex = featureIndex, featureIndexIdx = 0 , featureSplits, instanceWeight = 1.0 )
56
53
} else {
57
- AggUpdateUtils .updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex) ,
58
- featureIndexIdx, instanceWeight = 1.0 )
54
+ AggUpdateUtils .updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0 ,
55
+ instanceWeight = 1.0 )
59
56
}
60
57
}
61
-
62
- updateAggregator(statsAggregator, featureIndex = 0 , featureIndexIdx = 0 , values, indices,
63
- labels, from, to, featureSplits)
64
58
statsAggregator
65
59
}
66
60
67
- /** Check that left/right impurities match what we'd expect for a split. */
61
+ /**
62
+ * Check that left/right impurities match what we'd expect for a split.
63
+ * @param labels Labels whose impurity information should be reflected in stats
64
+ * @param stats ImpurityStats object containing impurity info for the left/right sides of a split
65
+ */
68
66
private def validateImpurityStats (
69
67
impurity : Impurity ,
70
68
labels : Array [Double ],
71
69
stats : ImpurityStats ,
72
70
expectedLeftStats : Array [Double ],
73
71
expectedRightStats : Array [Double ]): Unit = {
74
- // Verify that impurity stats were computed correctly for split
72
+ // Compute impurity for our data points manually
75
73
val numClasses = (labels.max + 1 ).toInt
76
74
val fullImpurityStatsArray
77
75
= Array .tabulate[Double ](numClasses)((label : Int ) => labels.count(_ == label).toDouble)
78
76
val fullImpurity = Entropy .calculate(fullImpurityStatsArray, labels.length)
77
+ // Verify that impurity stats were computed correctly for split
79
78
assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
80
79
assert(stats.impurity === fullImpurity)
81
80
assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
@@ -87,37 +86,37 @@ class TreeSplitUtilsSuite
87
86
88
87
test(" chooseSplit: choose correct type of split (continuous split)" ) {
89
88
// Construct (binned) continuous data
90
- val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
91
- val values = Array (8 , 1 , 1 , 2 , 3 , 5 , 6 )
89
+ val labels = Array (0.0 , 0.0 , 1.0 )
90
+ val values = Array (1 , 2 , 3 )
92
91
val featureIndex = 0
93
92
// Get an array of continuous splits corresponding to values in our binned data
94
- val splits = TreeTests .getContinuousSplits(1 .to(8 ).toArray, featureIndex = 0 )
93
+ val splits = TreeTests .getContinuousSplits(thresholds = values.distinct.sorted,
94
+ featureIndex = 0 )
95
95
// Construct DTStatsAggregator, compute sufficient stats
96
- val metadata = TreeTests .getMetadata(numExamples = 7 ,
97
- numFeatures = 1 , numClasses = 2 , Map .empty)
98
- val statsAggregator = getAggregator(metadata, values, from = 1 , to = 4 , labels, splits)
96
+ val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
97
+ numClasses = 2 , Map .empty)
98
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
99
99
// Choose split, check that it's a valid ContinuousSplit
100
- val (split1, stats1 ) = SplitUtils .chooseSplit(statsAggregator, featureIndex, featureIndex,
100
+ val (split, stats ) = SplitUtils .chooseSplit(statsAggregator, featureIndex, featureIndex,
101
101
splits)
102
- assert(stats1 .valid && split1 .isInstanceOf [ContinuousSplit ])
102
+ assert(stats .valid && split .isInstanceOf [ContinuousSplit ])
103
103
}
104
104
105
105
test(" chooseSplit: choose correct type of split (categorical split)" ) {
106
106
// Construct categorical data
107
- val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
107
+ val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 1.0 )
108
108
val featureArity = 3
109
- val values = Array (0 , 0 , 1 , 1 , 1 , 2 , 2 )
109
+ val values = Array (0 , 0 , 1 , 2 , 2 )
110
110
val featureIndex = 0
111
111
// Construct DTStatsAggregator, compute sufficient stats
112
- val metadata = TreeTests .getMetadata(numExamples = 7 ,
113
- numFeatures = 1 , numClasses = 2 , Map (featureIndex -> featureArity))
112
+ val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
113
+ numClasses = 2 , Map (featureIndex -> featureArity))
114
114
val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
115
- val statsAggregator = getAggregator(metadata, values, from = 1 , to = 4 , labels, splits)
115
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
116
116
// Choose split, check that it's a valid categorical split
117
- val (split2, stats2) = SplitUtils .chooseSplit(statsAggregator = statsAggregator,
118
- featureIndex = featureIndex, featureIndexIdx = featureIndex,
119
- featureSplits = splits)
120
- assert(stats2.valid && split2.isInstanceOf [CategoricalSplit ])
117
+ val (split, stats) = SplitUtils .chooseSplit(statsAggregator = statsAggregator,
118
+ featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits)
119
+ assert(stats.valid && split.isInstanceOf [CategoricalSplit ])
121
120
}
122
121
123
122
test(" chooseOrderedCategoricalSplit: basic case" ) {
@@ -128,15 +127,14 @@ class TreeSplitUtilsSuite
128
127
expectedLeftCategories : Array [Double ],
129
128
expectedLeftStats : Array [Double ],
130
129
expectedRightStats : Array [Double ]): Unit = {
130
+ // Set up metadata for ordered categorical feature
131
131
val featureIndex = 0
132
- // Construct FeatureVector to store categorical data
133
132
val featureArity = values.max + 1
134
133
val arityMap = Map [Int , Int ](featureIndex -> featureArity)
135
- // Construct DTStatsAggregator, compute sufficient stats
136
134
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
137
135
numClasses = 2 , arityMap, unorderedFeatures = Some (Set .empty))
138
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
139
- labels, featureSplits = Array .empty)
136
+ // Construct DTStatsAggregator, compute sufficient stats
137
+ val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array .empty)
140
138
// Choose split
141
139
val (split, stats) =
142
140
SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
@@ -155,12 +153,18 @@ class TreeSplitUtilsSuite
155
153
validateImpurityStats(Entropy , labels, stats, expectedLeftStats, expectedRightStats)
156
154
}
157
155
156
+ // Test a single split: The left side of our split should contain the two points with label 0,
157
+ // the right side of our split should contain the five points with label 1
158
158
val values = Array (0 , 0 , 1 , 2 , 2 , 2 , 2 )
159
159
val labels1 = Array (0 , 0 , 1 , 1 , 1 , 1 , 1 ).map(_.toDouble)
160
- testHelper(values, labels1, Array (0.0 ), Array (2.0 , 0.0 ), Array (0.0 , 5.0 ))
160
+ testHelper(values, labels1, expectedLeftCategories = Array (0.0 ),
161
+ expectedLeftStats = Array (2.0 , 0.0 ), expectedRightStats = Array (0.0 , 5.0 ))
161
162
163
+ // Test a single split: The left side of our split should contain the three points with label 0,
164
+ // the right side of our split should contain the four points with label 1
162
165
val labels2 = Array (0 , 0 , 0 , 1 , 1 , 1 , 1 ).map(_.toDouble)
163
- testHelper(values, labels2, Array (0.0 , 1.0 ), Array (3.0 , 0.0 ), Array (0.0 , 4.0 ))
166
+ testHelper(values, labels2, expectedLeftCategories = Array (0.0 , 1.0 ),
167
+ expectedLeftStats = Array (3.0 , 0.0 ), expectedRightStats = Array (0.0 , 4.0 ))
164
168
}
165
169
166
170
test(" chooseOrderedCategoricalSplit: return bad stats if we should not split" ) {
@@ -172,8 +176,7 @@ class TreeSplitUtilsSuite
172
176
// Construct DTStatsAggregator, compute sufficient stats
173
177
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
174
178
numClasses = 2 , Map (featureIndex -> featureArity), unorderedFeatures = Some (Set .empty))
175
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
176
- labels, featureSplits = Array .empty)
179
+ val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array .empty)
177
180
// Choose split, verify that it's invalid
178
181
val (_, stats) = SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
179
182
featureIndex)
@@ -186,15 +189,15 @@ class TreeSplitUtilsSuite
186
189
// label: 0 --> values: 1
187
190
// label: 1 --> values: 0, 2
188
191
// label: 2 --> values: 2
192
+ // Expected split: feature value 1 on the left, values (0, 2) on the right
189
193
val values = Array (1 , 1 , 0 , 2 , 2 )
190
194
val featureArity = values.max + 1
191
195
val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 2.0 )
192
196
// Construct DTStatsAggregator, compute sufficient stats
193
197
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
194
198
numClasses = 3 , Map (featureIndex -> featureArity))
195
199
val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
196
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
197
- labels, splits)
200
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
198
201
// Choose split
199
202
val (split, stats) =
200
203
SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
@@ -214,7 +217,7 @@ class TreeSplitUtilsSuite
214
217
}
215
218
216
219
test(" chooseUnorderedCategoricalSplit: return bad stats if we should not split" ) {
217
- // Construct data for unordered categorical feature
220
+ // Construct data for unordered categorical feature; all points have label 1
218
221
val featureIndex = 0
219
222
val featureArity = 4
220
223
val values = Array (3 , 1 , 0 , 2 , 2 )
@@ -223,8 +226,7 @@ class TreeSplitUtilsSuite
223
226
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
224
227
numClasses = 2 , Map (featureIndex -> featureArity))
225
228
val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
226
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
227
- splits)
229
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
228
230
// Choose split, verify that it's invalid
229
231
val (_, stats) = SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
230
232
featureIndex, splits)
@@ -241,8 +243,7 @@ class TreeSplitUtilsSuite
241
243
val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
242
244
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
243
245
numClasses = 2 , Map .empty)
244
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
245
- splits)
246
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
246
247
247
248
// Choose split, verify that it has expected threshold
248
249
val (split, stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
@@ -261,7 +262,7 @@ class TreeSplitUtilsSuite
261
262
}
262
263
263
264
test(" chooseContinuousSplit: return bad stats if we should not split" ) {
264
- // Construct data for continuous feature
265
+ // Construct data for continuous feature; all points have label 0
265
266
val featureIndex = 0
266
267
val thresholds = Array (0 , 1 , 2 , 3 )
267
268
val values = thresholds.indices.toArray
@@ -270,10 +271,9 @@ class TreeSplitUtilsSuite
270
271
val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
271
272
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
272
273
numClasses = 2 , Map .empty[Int , Int ])
273
- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
274
- splits)
274
+ val statsAggregator = getAggregator(metadata, values, labels, splits)
275
275
// Choose split, verify that it's invalid
276
- val (split , stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
276
+ val (_ , stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
277
277
featureIndex, splits)
278
278
assert(! stats.valid)
279
279
}
0 commit comments