
Commit f2e3fbd

Local tree training part 1 (refactor RandomForest.scala into utility classes)

1 parent: 1e6f760

6 files changed (+501, -238 lines)

mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala

Lines changed: 5 additions & 9 deletions

@@ -276,14 +276,10 @@ private[tree] class LearningNode(
       new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
         leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
     } else {
-      if (stats.valid) {
-        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
-          stats.impurityCalculator)
-      } else {
-        // Here we want to keep same behavior with the old mllib.DecisionTreeModel
-        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
-      }
-
+      assert(stats != null, "Unknown error during Decision Tree learning. Could not convert " +
+        "LearningNode to Node")
+      new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+        stats.impurityCalculator)
     }
   }

@@ -334,7 +330,7 @@ private[tree] object LearningNode {
       id: Int,
       isLeaf: Boolean,
       stats: ImpurityStats): LearningNode = {
-    new LearningNode(id, None, None, None, false, stats)
+    new LearningNode(id, None, None, None, isLeaf, stats)
   }

   /** Create an empty node with the given node index. Values must be set later on. */
mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala

Lines changed: 85 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.ml.tree.Split

/**
 * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training.
 */
private[impl] object AggUpdateUtils {

  /**
   * Updates the parent node stats of the passed-in impurity aggregator with the labels
   * corresponding to the feature values at indices [from, to).
   * @param indices Array of row indices for feature values; indices(i) = row index of the ith
   *                feature value
   */
  private[impl] def updateParentImpurity(
      statsAggregator: DTStatsAggregator,
      indices: Array[Int],
      from: Int,
      to: Int,
      instanceWeights: Array[Double],
      labels: Array[Double]): Unit = {
    from.until(to).foreach { idx =>
      val rowIndex = indices(idx)
      val label = labels(rowIndex)
      statsAggregator.updateParent(label, instanceWeights(rowIndex))
    }
  }

  /**
   * Update aggregator for an (unordered feature, label) pair
   * @param featureSplits Array of splits for the current feature
   */
  private[impl] def updateUnorderedFeature(
      agg: DTStatsAggregator,
      featureValue: Int,
      label: Double,
      featureIndex: Int,
      featureIndexIdx: Int,
      featureSplits: Array[Split],
      instanceWeight: Double): Unit = {
    val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
    // Each unordered split has a corresponding bin for impurity stats of data points that fall
    // onto the left side of the split. For each unordered split, update left-side bin if
    // applicable for the current data point.
    val numSplits = agg.metadata.numSplits(featureIndex)
    var splitIndex = 0
    while (splitIndex < numSplits) {
      if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
        agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight)
      }
      splitIndex += 1
    }
  }

  /** Update aggregator for an (ordered feature, label) pair */
  private[impl] def updateOrderedFeature(
      agg: DTStatsAggregator,
      featureValue: Int,
      label: Double,
      featureIndexIdx: Int,
      instanceWeight: Double): Unit = {
    // The bin index of an ordered feature is just the feature value itself
    val binIndex = featureValue
    agg.update(featureIndexIdx, binIndex, label, instanceWeight)
  }

}
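
The two update paths above differ only in how a data point maps to bins: an ordered feature indexes a bin directly by its (binned) value, while each candidate split of an unordered categorical feature owns a left-side bin that is updated whenever the point's category falls on that split's left side. Below is a minimal, self-contained sketch of that idea; it does not use Spark's DTStatsAggregator or Split, and SimpleAgg and categoriesGoingLeft are hypothetical stand-ins for illustration only.

// Simplified model of the binning performed by AggUpdateUtils (hypothetical helper,
// not part of this patch). Bins hold per-class weighted label counts.
object BinUpdateSketch {

  // One aggregator: numBins bins, each with a weighted count per class (binary classification).
  final class SimpleAgg(numBins: Int, numClasses: Int = 2) {
    val bins: Array[Array[Double]] = Array.fill(numBins)(new Array[Double](numClasses))
    def update(binIndex: Int, label: Double, weight: Double): Unit =
      bins(binIndex)(label.toInt) += weight
  }

  // Ordered (already-binned) feature: the bin index is the feature value itself,
  // mirroring updateOrderedFeature.
  def orderedUpdate(agg: SimpleAgg, featureValue: Int, label: Double, weight: Double): Unit =
    agg.update(featureValue, label, weight)

  // Unordered categorical feature: each candidate split keeps one left-side bin.
  // Update every split whose left-category set contains this value, mirroring the
  // loop in updateUnorderedFeature.
  def unorderedUpdate(
      agg: SimpleAgg,
      categoriesGoingLeft: Array[Set[Int]], // one left-category set per candidate split
      featureValue: Int,
      label: Double,
      weight: Double): Unit = {
    var splitIndex = 0
    while (splitIndex < categoriesGoingLeft.length) {
      if (categoriesGoingLeft(splitIndex).contains(featureValue)) {
        agg.update(splitIndex, label, weight)
      }
      splitIndex += 1
    }
  }
}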
mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala

Lines changed: 135 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.ImpurityStats

/** Helper methods for impurity-related calculations during node split decisions. */
private[impl] object ImpurityUtils {

  /**
   * Get impurity calculator containing statistics for all labels for rows corresponding to
   * feature values in [from, to).
   * @param indices indices(i) = row index corresponding to ith feature value
   */
  private[impl] def getParentImpurityCalculator(
      metadata: DecisionTreeMetadata,
      indices: Array[Int],
      from: Int,
      to: Int,
      instanceWeights: Array[Double],
      labels: Array[Double]): ImpurityCalculator = {
    // Compute sufficient stats (e.g. label counts) for all data at the current node,
    // store result in currNodeStatsAgg.parentStats so that we can share it across
    // all features for the current node
    val currNodeStatsAgg = new DTStatsAggregator(metadata, featureSubset = None)
    AggUpdateUtils.updateParentImpurity(currNodeStatsAgg, indices, from, to,
      instanceWeights, labels)
    currNodeStatsAgg.getParentImpurityCalculator()
  }

  /**
   * Calculate the impurity statistics for a given (feature, split) based upon left/right
   * aggregates.
   *
   * @param parentImpurityCalculator An ImpurityCalculator containing the impurity stats
   *                                 of the node currently being split.
   * @param leftImpurityCalculator left node aggregates for this (feature, split)
   * @param rightImpurityCalculator right node aggregate for this (feature, split)
   * @param metadata learning and dataset metadata for DecisionTree
   * @return Impurity statistics for this (feature, split)
   */
  private[impl] def calculateImpurityStats(
      parentImpurityCalculator: ImpurityCalculator,
      leftImpurityCalculator: ImpurityCalculator,
      rightImpurityCalculator: ImpurityCalculator,
      metadata: DecisionTreeMetadata): ImpurityStats = {

    val impurity: Double = parentImpurityCalculator.calculate()

    val leftCount = leftImpurityCalculator.count
    val rightCount = rightImpurityCalculator.count

    val totalCount = leftCount + rightCount

    // If left child or right child doesn't satisfy minimum instances per node,
    // then this split is invalid, return invalid information gain stats.
    if ((leftCount < metadata.minInstancesPerNode) ||
        (rightCount < metadata.minInstancesPerNode)) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    val rightImpurity = rightImpurityCalculator.calculate()

    val leftWeight = leftCount / totalCount.toDouble
    val rightWeight = rightCount / totalCount.toDouble

    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

    // If information gain doesn't satisfy minimum information gain,
    // then this split is invalid, return invalid information gain stats.
    if (gain < metadata.minInfoGain) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

    // If information gain is non-positive but doesn't violate the minimum info gain constraint,
    // return a stats object with correct values but valid = false to indicate that we should not
    // split.
    if (gain <= 0) {
      return new ImpurityStats(gain, impurity, parentImpurityCalculator, leftImpurityCalculator,
        rightImpurityCalculator, valid = false)
    }

    new ImpurityStats(gain, impurity, parentImpurityCalculator,
      leftImpurityCalculator, rightImpurityCalculator)
  }

  /**
   * Given an impurity aggregator containing label statistics for a given (node, feature, bin),
   * returns the corresponding "centroid", used to order bins while computing best splits.
   *
   * @param metadata learning and dataset metadata for DecisionTree
   */
  private[impl] def getCentroid(
      metadata: DecisionTreeMetadata,
      binStats: ImpurityCalculator): Double = {

    if (binStats.count != 0) {
      if (metadata.isMulticlass) {
        // multiclass classification
        // For categorical features in multiclass classification,
        // the bins are ordered by the impurity of their corresponding labels.
        binStats.calculate()
      } else if (metadata.isClassification) {
        // binary classification
        // For categorical features in binary classification,
        // the bins are ordered by the count of class 1.
        binStats.stats(1)
      } else {
        // regression
        // For categorical features in regression and binary classification,
        // the bins are ordered by the prediction.
        binStats.predict
      }
    } else {
      Double.MaxValue
    }
  }
}
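
As a concrete check on the gain formula in calculateImpurityStats above (gain = parent impurity minus the count-weighted impurities of the two children), here is a small self-contained sketch that uses Gini impurity over raw per-class counts; GainSketch, gini, and weightedGain are illustrative names, not part of this patch.

// Hypothetical sketch: weighted-impurity information gain with Gini impurity.
object GainSketch {

  // Gini impurity of a vector of per-class counts: 1 - sum_k p_k^2.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  // gain = impurity(parent) - wLeft * impurity(left) - wRight * impurity(right),
  // where each weight is that child's share of the parent's total count.
  def weightedGain(left: Array[Double], right: Array[Double]): Double = {
    val parent = left.zip(right).map { case (l, r) => l + r }
    val leftCount = left.sum
    val rightCount = right.sum
    val total = leftCount + rightCount
    gini(parent) - (leftCount / total) * gini(left) - (rightCount / total) * gini(right)
  }

  def main(args: Array[String]): Unit = {
    // Parent holds 10 of class 0 and 10 of class 1; a candidate split sends (8, 2) left
    // and (2, 8) right. Parent Gini = 0.5, each child Gini = 0.32, so gain = 0.18.
    println(f"gain = ${weightedGain(Array(8.0, 2.0), Array(2.0, 8.0))}%.2f")
  }
}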

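getCentroid only produces a per-bin ordering key; a rough, hypothetical sketch of how such a key is commonly used for categorical features follows: sort the categories by centroid, then consider only the splits that separate a prefix of the sorted order from the rest. CategoryStats and orderCategories below are illustrative stand-ins, not part of the patch.

// Hypothetical sketch: ordering categorical bins by centroid for binary classification.
object CentroidOrderingSketch {

  // Per-category weighted label counts.
  final case class CategoryStats(category: Int, classCounts: Array[Double]) {
    def count: Double = classCounts.sum
    // Binary-classification centroid: count of class 1 (mirrors binStats.stats(1) above);
    // empty bins sort last, mirroring the Double.MaxValue case.
    def centroid: Double = if (count != 0) classCounts(1) else Double.MaxValue
  }

  // Candidate splits are the prefixes of this ordering: left side = first k categories,
  // right side = the remaining ones, for k = 1 .. numCategories - 1.
  def orderCategories(stats: Seq[CategoryStats]): Seq[Int] =
    stats.sortBy(_.centroid).map(_.category)

  def main(args: Array[String]): Unit = {
    val stats = Seq(
      CategoryStats(0, Array(9.0, 1.0)), // mostly class 0, centroid 1.0
      CategoryStats(1, Array(2.0, 8.0)), // mostly class 1, centroid 8.0
      CategoryStats(2, Array(5.0, 5.0))) // mixed, centroid 5.0
    println(orderCategories(stats)) // List(0, 2, 1)
  }
}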