Skip to content

Commit b93f9f3

Browse files
committed
Add TreeSplitUtilsSuite, refactor it to not depend on any local tree training code
1 parent 320c32e commit b93f9f3

File tree

1 file changed

+54
-29
lines changed

1 file changed

+54
-29
lines changed

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
2828
class TreeSplitUtilsSuite
2929
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3030

31+
/**
32+
* Iterate over feature values and labels for a specific (node, feature), updating stats
33+
* aggregator for the current node.
34+
*/
35+
private[impl] def updateAggregator(
36+
statsAggregator: DTStatsAggregator,
37+
featureIndex: Int,
38+
values: Array[Int],
39+
indices: Array[Int],
40+
instanceWeights: Array[Double],
41+
labels: Array[Double],
42+
from: Int,
43+
to: Int,
44+
featureIndexIdx: Int,
45+
featureSplits: Array[Split]): Unit = {
46+
val metadata = statsAggregator.metadata
47+
from.until(to).foreach { idx =>
48+
val rowIndex = indices(idx)
49+
if (metadata.isUnordered(featureIndex)) {
50+
AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
51+
featureIndex = featureIndex, featureIndexIdx, featureSplits,
52+
instanceWeight = instanceWeights(rowIndex))
53+
} else {
54+
AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
55+
featureIndexIdx, instanceWeight = instanceWeights(rowIndex))
56+
}
57+
}
58+
}
59+
3160
/**
3261
* Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
3362
* with the data from the specified training points.
@@ -40,12 +69,13 @@ class TreeSplitUtilsSuite
4069
labels: Array[Double],
4170
featureSplits: Array[Split]): DTStatsAggregator = {
4271

72+
val featureIndex = 0
4373
val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
4474
val instanceWeights = Array.fill[Double](values.length)(1.0)
4575
val indices = values.indices.toArray
4676
AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
47-
LocalDecisionTree.updateAggregator(statsAggregator, col, indices, instanceWeights, labels,
48-
from, to, col.featureIndex, featureSplits)
77+
updateAggregator(statsAggregator, featureIndex = 0, values, indices, instanceWeights, labels,
78+
from, to, featureIndex, featureSplits)
4979
statsAggregator
5080
}
5181

@@ -73,34 +103,34 @@ class TreeSplitUtilsSuite
73103
test("chooseSplit: choose correct type of split (continuous split)") {
74104
// Construct (binned) continuous data
75105
val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
76-
val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6))
106+
val values = Array(8, 1, 1, 2, 3, 5, 6)
107+
val featureIndex = 0
77108
// Get an array of continuous splits corresponding to values in our binned data
78109
val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0)
79110
// Construct DTStatsAggregator, compute sufficient stats
80111
val metadata = TreeTests.getMetadata(numExamples = 7,
81112
numFeatures = 1, numClasses = 2, Map.empty)
82-
val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits)
113+
val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
83114
// Choose split, check that it's a valid ContinuousSplit
84-
val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex,
85-
col.featureIndex, splits)
115+
val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex,
116+
splits)
86117
assert(stats1.valid && split1.isInstanceOf[ContinuousSplit])
87118
}
88119

89120
test("chooseSplit: choose correct type of split (categorical split)") {
90121
// Construct categorical data
91122
val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0)
92-
val featureIndex = 0
93123
val featureArity = 3
94124
val values = Array(0, 0, 1, 1, 1, 2, 2)
95-
val col = FeatureColumn(featureIndex, values)
125+
val featureIndex = 0
96126
// Construct DTStatsAggregator, compute sufficient stats
97127
val metadata = TreeTests.getMetadata(numExamples = 7,
98128
numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity))
99129
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
100-
val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits)
130+
val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits)
101131
// Choose split, check that it's a valid categorical split
102132
val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator,
103-
featureIndex = col.featureIndex, featureIndexIdx = col.featureIndex,
133+
featureIndex = featureIndex, featureIndexIdx = featureIndex,
104134
featureSplits = splits)
105135
assert(stats2.valid && split2.isInstanceOf[CategoricalSplit])
106136
}
@@ -117,16 +147,14 @@ class TreeSplitUtilsSuite
117147
// Construct FeatureVector to store categorical data
118148
val featureArity = values.max + 1
119149
val arityMap = Map[Int, Int](featureIndex -> featureArity)
120-
val col = FeatureColumn(featureIndex = 0, values = values)
121150
// Construct DTStatsAggregator, compute sufficient stats
122151
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
123152
numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty))
124-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
153+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
125154
labels, featureSplits = Array.empty)
126155
// Choose split
127156
val (split, stats) =
128-
SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex,
129-
col.featureIndex)
157+
SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
130158
// Verify that split has the expected left-side/right-side categories
131159
val expectedRightCategories = Range(0, featureArity)
132160
.filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray
@@ -156,15 +184,14 @@ class TreeSplitUtilsSuite
156184
val values = Array(0, 0, 1, 2, 2, 2, 2)
157185
val featureArity = values.max + 1
158186
val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
159-
val col = FeatureColumn(featureIndex, values)
160187
// Construct DTStatsAggregator, compute sufficient stats
161188
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
162189
numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty))
163-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
190+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
164191
labels, featureSplits = Array.empty)
165192
// Choose split, verify that it's invalid
166-
val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex,
167-
col.featureIndex)
193+
val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
194+
featureIndex)
168195
assert(!stats.valid)
169196
}
170197

@@ -177,17 +204,16 @@ class TreeSplitUtilsSuite
177204
val values = Array(1, 1, 0, 2, 2)
178205
val featureArity = values.max + 1
179206
val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0)
180-
val col = FeatureColumn(featureIndex, values)
181207
// Construct DTStatsAggregator, compute sufficient stats
182208
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
183209
numClasses = 3, Map(featureIndex -> featureArity))
184210
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
185-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length,
211+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
186212
labels, splits)
187213
// Choose split
188214
val (split, stats) =
189-
SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, col.featureIndex,
190-
col.featureIndex, splits)
215+
SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
216+
splits)
191217
// Verify that split has the expected left-side/right-side categories
192218
split match {
193219
case s: CategoricalSplit =>
@@ -208,12 +234,12 @@ class TreeSplitUtilsSuite
208234
val featureArity = 4
209235
val values = Array(3, 1, 0, 2, 2)
210236
val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0)
211-
val col = FeatureColumn(featureIndex, values)
212237
// Construct DTStatsAggregator, compute sufficient stats
213238
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
214239
numClasses = 2, Map(featureIndex -> featureArity))
215240
val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
216-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
241+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
242+
splits)
217243
// Choose split, verify that it's invalid
218244
val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
219245
featureIndex, splits)
@@ -226,13 +252,12 @@ class TreeSplitUtilsSuite
226252
val thresholds = Array(0, 1, 2, 3)
227253
val values = thresholds.indices.toArray
228254
val labels = Array(0.0, 0.0, 1.0, 1.0)
229-
val col = FeatureColumn(featureIndex = featureIndex, values = values)
230-
231255
// Construct DTStatsAggregator, compute sufficient stats
232256
val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
233257
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
234258
numClasses = 2, Map.empty)
235-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
259+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
260+
splits)
236261

237262
// Choose split, verify that it has expected threshold
238263
val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
@@ -256,12 +281,12 @@ class TreeSplitUtilsSuite
256281
val thresholds = Array(0, 1, 2, 3)
257282
val values = thresholds.indices.toArray
258283
val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0)
259-
val col = FeatureColumn(featureIndex = featureIndex, values = values)
260284
// Construct DTStatsAggregator, compute sufficient stats
261285
val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
262286
val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
263287
numClasses = 2, Map.empty[Int, Int])
264-
val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits)
288+
val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
289+
splits)
265290
// Choose split, verify that it's invalid
266291
val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
267292
featureIndex, splits)

0 commit comments

Comments
 (0)