@@ -186,13 +186,13 @@ class TreeSplitUtilsSuite
186
186
test(" chooseUnorderedCategoricalSplit: basic case" ) {
187
187
val featureIndex = 0
188
188
// Construct data for unordered categorical feature
189
- // label: 0 --> values: 1
190
- // label: 1 --> values: 0, 2
191
- // label: 2 --> values: 2
192
- // Expected split: feature value 1 on the left, values (0, 2) on the right
193
- val values = Array (1 , 1 , 0 , 2 , 2 )
189
+ // label: 0 --> values: 0, 1
190
+ // label: 1 --> values: 2, 3
191
+ // label: 2 --> values: 2, 2, 4
192
+ // Expected split: feature values (0, 1) on the left, values (2, 3, 4) on the right
193
+ val values = Array (0 , 1 , 2 , 3 , 2 , 2 , 4 )
194
+ val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 2.0 , 2.0 , 2.0 )
194
195
val featureArity = values.max + 1
195
- val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 2.0 )
196
196
// Construct DTStatsAggregator, compute sufficient stats
197
197
val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
198
198
numClasses = 3 , Map (featureIndex -> featureArity))
@@ -206,14 +206,14 @@ class TreeSplitUtilsSuite
206
206
split match {
207
207
case s : CategoricalSplit =>
208
208
assert(s.featureIndex === featureIndex)
209
- assert(s.leftCategories.toSet === Set (1.0 ))
210
- assert(s.rightCategories.toSet === Set (0.0 , 2.0 ))
209
+ assert(s.leftCategories.toSet === Set (0.0 , 1.0 ))
210
+ assert(s.rightCategories.toSet === Set (2.0 , 3.0 , 4.0 ))
211
211
case _ =>
212
212
throw new AssertionError (
213
213
s" Expected CategoricalSplit but got ${split.getClass.getSimpleName}" )
214
214
}
215
215
validateImpurityStats(Entropy , labels, stats, expectedLeftStats = Array (2.0 , 0.0 , 0.0 ),
216
- expectedRightStats = Array (0.0 , 2.0 , 1.0 ))
216
+ expectedRightStats = Array (0.0 , 2.0 , 3.0 ))
217
217
}
218
218
219
219
test(" chooseUnorderedCategoricalSplit: return bad stats if we should not split" ) {
0 commit comments