Skip to content

Commit 5bcccda

Browse files
committed
Update TreeSplitUtilsSuite
1 parent b6291e1 commit 5bcccda

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -186,13 +186,13 @@ class TreeSplitUtilsSuite
186186
test("chooseUnorderedCategoricalSplit: basic case") {
187187
val featureIndex = 0
188188
// Construct data for unordered categorical feature
189-
// label: 0 --> values: 1
190-
// label: 1 --> values: 0, 2
191-
// label: 2 --> values: 2
192-
// Expected split: feature value 1 on the left, values (0, 2) on the right
193-
val values = Array(1, 1, 0, 2, 2)
189+
// label: 0 --> values: 0, 1
190+
// label: 1 --> values: 2, 3
191+
// label: 2 --> values: 2, 2, 4
192+
// Expected split: feature values (0, 1) on the left, values (2, 3, 4) on the right
193+
val values = Array(0, 1, 2, 3, 2, 2, 4)
194+
val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
194195
val featureArity = values.max + 1
195-
val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0)
196196
// Construct DTStatsAggregator, compute sufficient stats
197197
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
198198
numClasses = 3, Map(featureIndex -> featureArity))
@@ -206,14 +206,14 @@ class TreeSplitUtilsSuite
206206
split match {
207207
case s: CategoricalSplit =>
208208
assert(s.featureIndex === featureIndex)
209-
assert(s.leftCategories.toSet === Set(1.0))
210-
assert(s.rightCategories.toSet === Set(0.0, 2.0))
209+
assert(s.leftCategories.toSet === Set(0.0, 1.0))
210+
assert(s.rightCategories.toSet === Set(2.0, 3.0, 4.0))
211211
case _ =>
212212
throw new AssertionError(
213213
s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}")
214214
}
215215
validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0, 0.0),
216-
expectedRightStats = Array(0.0, 2.0, 1.0))
216+
expectedRightStats = Array(0.0, 2.0, 3.0))
217217
}
218218

219219
test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") {

0 commit comments

Comments
 (0)