Skip to content

Commit b6291e1

Browse files
committed
Simplify TreeSplitUtilsSuite
1 parent 31ef80b commit b6291e1

File tree

2 files changed

+59
-59
lines changed

2 files changed

+59
-59
lines changed

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,52 +30,51 @@ class TreeSplitUtilsSuite
3030

3131
/**
3232
* Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
33-
* with the data from the specified training points.
33+
* with the data from the specified training points. Assumes a feature index of 0 and that
34+
* all training points have the same weights (1.0).
3435
*/
3536
private def getAggregator(
3637
metadata: DecisionTreeMetadata,
3738
values: Array[Int],
38-
from: Int,
39-
to: Int,
4039
labels: Array[Double],
4140
featureSplits: Array[Split]): DTStatsAggregator = {
42-
43-
val featureIndex = 0
41+
// Create stats aggregator
4442
val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
45-
val indices = values.indices.toArray
46-
val instanceWeights = Array.fill[Double](values.length)(1.0)
4743
// Update parent impurity stats
48-
AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
44+
val featureIndex = 0
45+
val instanceWeights = Array.fill[Double](values.length)(1.0)
46+
AggUpdateUtils.updateParentImpurity(statsAggregator, indices = values.indices.toArray,
47+
from = 0, to = values.length, instanceWeights, labels)
4948
// Update current aggregator's impurity stats
50-
from.until(to).foreach { idx =>
51-
val rowIndex = indices(idx)
49+
values.zip(labels).foreach { case (value: Int, label: Double) =>
5250
if (metadata.isUnordered(featureIndex)) {
53-
AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
54-
featureIndex = featureIndex, featureIndexIdx, featureSplits,
55-
instanceWeight = 1.0)
51+
AggUpdateUtils.updateUnorderedFeature(statsAggregator, value, label,
52+
featureIndex = featureIndex, featureIndexIdx = 0, featureSplits, instanceWeight = 1.0)
5653
} else {
57-
AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
58-
featureIndexIdx, instanceWeight = 1.0)
54+
AggUpdateUtils.updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0,
55+
instanceWeight = 1.0)
5956
}
6057
}
61-
62-
updateAggregator(statsAggregator, featureIndex = 0, featureIndexIdx = 0, values, indices,
63-
labels, from, to, featureSplits)
6458
statsAggregator
6559
}
6660

67-
/** Check that left/right impurities match what we'd expect for a split. */
61+
/**
62+
* Check that left/right impurities match what we'd expect for a split.
63+
* @param labels Labels whose impurity information should be reflected in stats
64+
* @param stats ImpurityStats object containing impurity info for the left/right sides of a split
65+
*/
6866
private def validateImpurityStats(
6967
impurity: Impurity,
7068
labels: Array[Double],
7169
stats: ImpurityStats,
7270
expectedLeftStats: Array[Double],
7371
expectedRightStats: Array[Double]): Unit = {
74-
// Verify that impurity stats were computed correctly for split
72+
// Compute impurity for our data points manually
7573
val numClasses = (labels.max + 1).toInt
7674
val fullImpurityStatsArray
7775
= Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble)
7876
val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length)
77+
// Verify that impurity stats were computed correctly for split
7978
assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
8079
assert(stats.impurity === fullImpurity)
8180
assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
@@ -87,37 +86,37 @@ class TreeSplitUtilsSuite
8786

8887
test("chooseSplit: choose correct type of split (continuous split)") {
8988
// Construct (binned) continuous data
90-
val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
91-
val values = Array(8, 1, 1, 2, 3, 5, 6)
89+
val labels = Array(0.0, 0.0, 1.0)
90+
val values = Array(1, 2, 3)
9291
val featureIndex = 0
9392
// Get an array of continuous splits corresponding to values in our binned data
94-
val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0)
93+
val splits = TreeTests.getContinuousSplits(thresholds = values.distinct.sorted,
94+
featureIndex = 0)
9595
// Construct DTStatsAggregator, compute sufficient stats
96-
val metadata = TreeTests.getMetadata(numExamples = 7,
97-
numFeatures = 1, numClasses = 2, Map.empty)
98-
val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
96+
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
97+
numClasses = 2, Map.empty)
98+
val statsAggregator = getAggregator(metadata, values, labels, splits)
9999
// Choose split, check that it's a valid ContinuousSplit
100-
val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex,
100+
val (split, stats) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex,
101101
splits)
102-
assert(stats1.valid && split1.isInstanceOf[ContinuousSplit])
102+
assert(stats.valid && split.isInstanceOf[ContinuousSplit])
103103
}
104104

105105
test("chooseSplit: choose correct type of split (categorical split)") {
106106
// Construct categorical data
107-
val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
107+
val labels = Array(0.0, 0.0, 1.0, 1.0, 1.0)
108108
val featureArity = 3
109-
val values = Array(0, 0, 1, 1, 1, 2, 2)
109+
val values = Array(0, 0, 1, 2, 2)
110110
val featureIndex = 0
111111
// Construct DTStatsAggregator, compute sufficient stats
112-
val metadata = TreeTests.getMetadata(numExamples = 7,
113-
numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity))
112+
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
113+
numClasses = 2, Map(featureIndex -> featureArity))
114114
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
115-
val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
115+
val statsAggregator = getAggregator(metadata, values, labels, splits)
116116
// Choose split, check that it's a valid categorical split
117-
val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator,
118-
featureIndex = featureIndex, featureIndexIdx = featureIndex,
119-
featureSplits = splits)
120-
assert(stats2.valid && split2.isInstanceOf[CategoricalSplit])
117+
val (split, stats) = SplitUtils.chooseSplit(statsAggregator = statsAggregator,
118+
featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits)
119+
assert(stats.valid && split.isInstanceOf[CategoricalSplit])
121120
}
122121

123122
test("chooseOrderedCategoricalSplit: basic case") {
@@ -128,15 +127,14 @@ class TreeSplitUtilsSuite
128127
expectedLeftCategories: Array[Double],
129128
expectedLeftStats: Array[Double],
130129
expectedRightStats: Array[Double]): Unit = {
130+
// Set up metadata for ordered categorical feature
131131
val featureIndex = 0
132-
// Construct FeatureVector to store categorical data
133132
val featureArity = values.max + 1
134133
val arityMap = Map[Int, Int](featureIndex -> featureArity)
135-
// Construct DTStatsAggregator, compute sufficient stats
136134
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
137135
numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty))
138-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
139-
labels, featureSplits = Array.empty)
136+
// Construct DTStatsAggregator, compute sufficient stats
137+
val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
140138
// Choose split
141139
val (split, stats) =
142140
SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
@@ -155,12 +153,18 @@ class TreeSplitUtilsSuite
155153
validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats)
156154
}
157155

156+
// Test a single split: The left side of our split should contain the two points with label 0,
157+
// the right side of our split should contain the five points with label 1
158158
val values = Array(0, 0, 1, 2, 2, 2, 2)
159159
val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble)
160-
testHelper(values, labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0))
160+
testHelper(values, labels1, expectedLeftCategories = Array(0.0),
161+
expectedLeftStats = Array(2.0, 0.0), expectedRightStats = Array(0.0, 5.0))
161162

163+
// Test a single split: The left side of our split should contain the three points with label 0,
164+
// the right side of our split should contain the four points with label 1
162165
val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble)
163-
testHelper(values, labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0))
166+
testHelper(values, labels2, expectedLeftCategories = Array(0.0, 1.0),
167+
expectedLeftStats = Array(3.0, 0.0), expectedRightStats = Array(0.0, 4.0))
164168
}
165169

166170
test("chooseOrderedCategoricalSplit: return bad stats if we should not split") {
@@ -172,8 +176,7 @@ class TreeSplitUtilsSuite
172176
// Construct DTStatsAggregator, compute sufficient stats
173177
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
174178
numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty))
175-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
176-
labels, featureSplits = Array.empty)
179+
val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
177180
// Choose split, verify that it's invalid
178181
val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
179182
featureIndex)
@@ -186,15 +189,15 @@ class TreeSplitUtilsSuite
186189
// label: 0 --> values: 1
187190
// label: 1 --> values: 0, 2
188191
// label: 2 --> values: 2
192+
// Expected split: feature value 1 on the left, values (0, 2) on the right
189193
val values = Array(1, 1, 0, 2, 2)
190194
val featureArity = values.max + 1
191195
val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0)
192196
// Construct DTStatsAggregator, compute sufficient stats
193197
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
194198
numClasses = 3, Map(featureIndex -> featureArity))
195199
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
196-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
197-
labels, splits)
200+
val statsAggregator = getAggregator(metadata, values, labels, splits)
198201
// Choose split
199202
val (split, stats) =
200203
SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
@@ -214,7 +217,7 @@ class TreeSplitUtilsSuite
214217
}
215218

216219
test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") {
217-
// Construct data for unordered categorical feature
220+
// Construct data for unordered categorical feature; all points have label 1
218221
val featureIndex = 0
219222
val featureArity = 4
220223
val values = Array(3, 1, 0, 2, 2)
@@ -223,8 +226,7 @@ class TreeSplitUtilsSuite
223226
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
224227
numClasses = 2, Map(featureIndex -> featureArity))
225228
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
226-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
227-
splits)
229+
val statsAggregator = getAggregator(metadata, values, labels, splits)
228230
// Choose split, verify that it's invalid
229231
val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
230232
featureIndex, splits)
@@ -241,8 +243,7 @@ class TreeSplitUtilsSuite
241243
val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
242244
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
243245
numClasses = 2, Map.empty)
244-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
245-
splits)
246+
val statsAggregator = getAggregator(metadata, values, labels, splits)
246247

247248
// Choose split, verify that it has expected threshold
248249
val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
@@ -261,7 +262,7 @@ class TreeSplitUtilsSuite
261262
}
262263

263264
test("chooseContinuousSplit: return bad stats if we should not split") {
264-
// Construct data for continuous feature
265+
// Construct data for continuous feature; all points have label 0
265266
val featureIndex = 0
266267
val thresholds = Array(0, 1, 2, 3)
267268
val values = thresholds.indices.toArray
@@ -270,10 +271,9 @@ class TreeSplitUtilsSuite
270271
val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
271272
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
272273
numClasses = 2, Map.empty[Int, Int])
273-
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
274-
splits)
274+
val statsAggregator = getAggregator(metadata, values, labels, splits)
275275
// Choose split, verify that it's invalid
276-
val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
276+
val (_, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
277277
featureIndex, splits)
278278
assert(!stats.valid)
279279
}

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,12 @@ private[ml] object TreeTests extends SparkFunSuite {
133133

134134
/**
135135
* Returns an array of continuous splits for the feature with index featureIndex and the passed-in
136-
* set of values. Creates one continuous split per value in values.
136+
* set of thresholds. Creates one continuous split per threshold in thresholds.
137137
*/
138138
private[impl] def getContinuousSplits(
139-
values: Array[Int],
139+
thresholds: Array[Int],
140140
featureIndex: Int): Array[Split] = {
141-
val splits = values.sorted.map {
141+
val splits = thresholds.sorted.map {
142142
new ContinuousSplit(featureIndex, _).asInstanceOf[Split]
143143
}
144144
splits

0 commit comments

Comments
 (0)