[SPARK-3162] [MLlib] Add local tree training for decision tree regressors #19433
Conversation
…calTreeDataSuite):
* TrainingInfo: primary local tree training data structure, contains all information required to describe state of algorithm at any point during learning
* FeatureVector: Stores data for an individual feature as an Array[Int]
…oth local & distributed training:
* AggUpdateUtils: Helper methods for updating sufficient stats for a given node
* ImpurityUtils: Helper methods for impurity-related calculations during node split decisions
* SplitUtils: Helper methods for choosing splits given sufficient stats
NOTE: Both ImpurityUtils and SplitUtils primarily contain code taken from RandomForest.scala, with slight modifications. Tests for SplitUtils are contained in the next commit.
* TreeSplitUtilsSuite: Test suite for SplitUtils
* TreeTests: Add utility method (getMetadata) for TreeSplitUtilsSuite
Also add methods used by these tests in LocalDecisionTree.scala, RandomForest.scala.
…lit calculations
@WeichenXu123 would you be able to take an initial look at this?
val numFeatures = rowStore(0).length
require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.")
// Return the transpose of the rowStore matrix
0.until(numFeatures).map { colIdx =>
TODO: replace this with an in-place matrix transpose for memory efficiency.
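For reference, a minimal sketch of the row-to-column transpose that the snippet above performs (illustrative code, not the PR's actual implementation; the rowToColumnStore name and the Array[Int] element type are assumptions). The TODO refers to avoiding the extra per-feature allocation below:

```scala
// Transpose a row-major dataset into one array per feature (column-major layout).
// Allocates a fresh array per feature; an in-place transpose would avoid this extra memory.
def rowToColumnStore(rowStore: Seq[Array[Int]]): Array[Array[Int]] = {
  require(rowStore.nonEmpty, "Local decision tree training requires at least one row.")
  val numFeatures = rowStore.head.length
  require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.")
  val numRows = rowStore.length
  val columns = Array.ofDim[Int](numFeatures, numRows)
  var rowIdx = 0
  rowStore.foreach { row =>
    var colIdx = 0
    while (colIdx < numFeatures) {
      columns(colIdx)(rowIdx) = row(colIdx)
      colIdx += 1
    }
    rowIdx += 1
  }
  columns
}
```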
@smurching Is this still WIP? If done, remove "[WIP]" and I will begin review, thanks!
Thanks! I'll remove the WIP. To clear things up for the future, I'd thought [WIP] was the appropriate tag for a PR that's ready for review but not ready to be merged (based on https://spark.apache.org/contributing.html) -- have we stopped using the WIP tag?
add to whitelist
Test build #82557 has finished for PR 19433 at commit
The failing tests come from this PR handling a) splits that have 0 gain differently from b) splits that fail to achieve the user-specified minimum gain. Previously we'd create a leaf node with valid impurity stats in case a) and invalid impurity stats in case b); this PR creates a leaf node with invalid impurity stats in both cases. As a fix I'd suggest returning valid impurity stats in case a), which will keep the process of determining split validity simple.
…ranspose in LocalDecisionTreeUtils. Changes made to fix tests:
* Return correct impurity stats for splits that achieved a gain of 0 but didn't violate user-specified constraints on min info gain or min instances per node
* Previously, ImpurityStats.impurity was set incorrectly in ImpurityStats.getInvalidImpurityStats(), requiring a correction in LearningNode.toNode. This commit fixes the issue by directly setting impurity = -1 in getInvalidSplits()
Test build #82570 has finished for PR 19433 at commit
The failing SparkR test is due to the following: in this PR we recompute parent node impurity stats when considering each split for a feature, instead of computing parent impurity stats once per feature. Repeatedly computing parent impurity stats results in slightly different impurity values at each iteration due to Double precision limitations. This in turn can cause different splits to be selected (e.g. if two splits have mathematically equal gains, Double precision limitations can cause one split to have a larger/smaller gain than the other, influencing tiebreaking).
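As a small self-contained illustration of the Double precision point above (not code from the PR):

```scala
// Mathematically equal sums computed in different orders can differ in the last bit,
// which is enough to flip a tie-break between two otherwise-equal split gains.
val summedLeftToRight = 0.1 + 0.2 + 0.3   // 0.6000000000000001
val summedRightToLeft = 0.3 + 0.2 + 0.1   // 0.6
println(summedLeftToRight == summedRightToLeft)  // false
println(summedLeftToRight > summedRightToLeft)   // true: a ">" comparison now breaks the "tie"
```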
…ats during best split selection
Test build #82652 has finished for PR 19433 at commit
I made a rough pass. I have only a few issues for now; I haven't gone into code details:
Thanks for the comments!
Sorry, realized I conflated feature subsampling and |
Test build #82717 has finished for PR 19433 at commit
Test build #82721 has finished for PR 19433 at commit
I made a deeper review pass. Later I will put more thoughts on the columnar feature storage design. Thanks!
// gives us the split bit value for each instance based on the instance's index.
// We copy our feature values into @tempVals and @tempIndices either:
// 1) in the [from, numLeftRows) range if the bit is false, or
// 2) in the [numBitsNotSet, to) range if the bit is true.
Although numLeftRows == numBitsNotSet, it is better to keep them the same in the doc.
Will change this, thanks for the catch!
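To make the partitioning described in the quoted comment concrete, here is a minimal self-contained sketch (names and signature are illustrative, not the PR's actual helper):

```scala
// Stably partition values(from until to) by a per-position split decision:
// positions whose bit is false land in the first numBitsNotSet slots of the range,
// positions whose bit is true land in the remaining slots, preserving relative order.
def partitionByBits(values: Array[Int], from: Int, to: Int, bits: Array[Boolean]): Unit = {
  val length = to - from
  val temp = new Array[Int](length)
  val numBitsNotSet = (from until to).count(i => !bits(i))
  var leftPos = 0                // next slot for a "left child" (bit unset) value
  var rightPos = numBitsNotSet   // next slot for a "right child" (bit set) value
  var i = from
  while (i < to) {
    if (!bits(i)) { temp(leftPos) = values(i); leftPos += 1 }
    else { temp(rightPos) = values(i); rightPos += 1 }
    i += 1
  }
  Array.copy(temp, 0, values, from, length)
}
```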
// Filter out leaf nodes from the previous iteration
val activeNonLeafs = activeNodes.zipWithIndex.filterNot(_._1.isLeaf)
// Iterate over the active nodes in the current level.
activeNonLeafs.flatMap { case (node: LearningNode, nodeIndex: Int) =>
The var names activeNodes and activeNonLeafs are not accurate, I think. Here the activeNodes are actually "next level nodes", including "probably splittable nodes (active nodes)" and "leaf nodes".
val activeNodes: Array[LearningNode] =
  computeBestSplits(trainingInfo, labels, metadata, splits)
// Filter active node periphery by impurity.
val estimatedRemainingActive = activeNodes.count(_.stats.impurity > 0.0)
Use activeNodes.count(_.isLeaf) instead, to make the code simpler. And as mentioned above, activeNodes would be better renamed to nextLevelNodes.
Agreed on using isLeaf instead of checking for positive impurity, thanks for the suggestion.
AFAICT at this point in the code activeNodes actually does refer to the nodes in the current level; the children of nodes in activeNodes are the nodes in the next level, and are returned by computeBestSplits. I forgot to include the return type of computeBestSplits in its method signature, which probably made this more confusing -- my mistake.
Yes. Sorry for confusing you. The change I suggested was:

val nextLevelNodes: Array[LearningNode] =
  computeBestSplits(trainingInfo, labels, metadata, splits)

Does it look more reasonable?
And change the member name in trainingInfo: TrainingInfo.activeNodes ==> TrainingInfo.currentLevelNodes
Gotcha, agreed on the naming change; how about currentLevelActiveNodes? Since only the non-leaf nodes from the current level are included.
Wait... I checked the code here: trainingInfo = trainingInfo.update(splits, activeNodes). So it seems you do not filter out the leaf nodes from the activeNodes (which is actually the nextLevelNodes I mentioned above). So I think TrainingInfo.activeNodes may still contain leaf nodes.
Oh true -- I'll reword the doc for currentLevelActiveNodes to say:

* @param currentLevelActiveNodes Nodes which are active (could still be split).
*        Inactive nodes are known to be leaves in the final tree.
*/
private[impl] case class TrainingInfo(
    columns: Array[FeatureVector],
    instanceWeights: Array[Double],
The instanceWeights are never updated between iterations, so why put them in the TrainingInfo structure?
Good call, I'll move instanceWeights outside TrainingInfo.
*/
private[impl] def updateParentImpurity(
    statsAggregator: DTStatsAggregator,
    col: FeatureVector,
Actually, updateParentImpurity has no relation to any particular feature column; you pass in the feature column here only to use its indices array, so passing any feature column would work. But this looks weird; maybe it could be better designed.
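A rough sketch of the kind of refactoring being suggested, passing only the row indices rather than a whole feature column (the trait and method names here are hypothetical, not Spark's actual API):

```scala
// Hypothetical aggregator interface: only the parent-update operation matters for this sketch.
trait ParentStatsAggregator {
  def updateParent(label: Double, instanceWeight: Double): Unit
}

// Update parent-node sufficient stats for rows indices(from until to); no feature column needed.
def updateParentImpurity(
    agg: ParentStatsAggregator,
    rowIndices: Array[Int],
    from: Int,
    to: Int,
    labels: Array[Double],
    instanceWeights: Array[Double]): Unit = {
  var i = from
  while (i < to) {
    val row = rowIndices(i)
    agg.updateParent(labels(row), instanceWeights(row))
    i += 1
  }
}
```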
    label: Double,
    featureIndex: Int,
    featureIndexIdx: Int,
    splits: Array[Array[Split]],
You only need to pass in the featureSplit: Array[Split], don't pass all splits for all features.
Good call, I'll make this change.
    from: Int,
    to: Int,
    split: Split,
    allSplits: Array[Array[Split]]): BitSet = {
Ditto, you only need to pass in the featureSplit: Array[Split], don't pass all splits for all features.
@smurching I found some issues and have some thoughts on the columnar features format:
* Move instanceWeights outside TrainingInfo
* Only pass a single array of splits (instead of an array of arrays of splits) when possible
Test build #83464 has finished for PR 19433 at commit
jenkins retest this please
Test build #83503 has finished for PR 19433 at commit
Test build #83507 has finished for PR 19433 at commit
CC @dbtsai in case you're interested b/c of Sequoia forests
Test build #3983 has finished for PR 19433 at commit
Done with a pass over the parts which refactor elements of RandomForest.scala into utility classes. Will review more after updates!
    agg: DTStatsAggregator,
    featureValue: Int,
    label: Double,
    featureIndex: Int,
featureIndex is not used
private[impl] def getNonConstantFeatures(
    metadata: DecisionTreeMetadata,
    featuresForNode: Option[Array[Int]]): Seq[(Int, Int)] = {
  Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
Was there a reason to remove the use of view and withFilter here? With the output of this method going through further Seq operations, I would expect the previous implementation to be more efficient.
At some point when refactoring I was hitting errors caused by a stateful operation within a map over the output of this method (IIRC the result of the map was accessed repeatedly, causing the stateful operation to inadvertently be run multiple times).
However using withFilter and view now seems to work, I'll change it back :)
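A simplified, self-contained sketch of the lazy approach being discussed (this is not the PR's exact getNonConstantFeatures; the numSplits parameter stands in for DecisionTreeMetadata, and plain filter is used where the real code can use withFilter to fuse with the downstream map):

```scala
// Lazily enumerate (featureIndexIdx, featureIndex) pairs, skipping constant features
// (those with zero candidate splits). Using a view avoids materializing an intermediate
// collection; downstream operations pull elements through the map/filter on demand.
def nonConstantFeatures(
    numFeaturesPerNode: Int,
    featuresForNode: Option[Array[Int]],
    numSplits: Int => Int): Iterable[(Int, Int)] = {
  Range(0, numFeaturesPerNode).view
    .map { featureIndexIdx =>
      // Translate the per-node feature slot to a global feature index
      // (identity when feature subsampling is not used).
      val featureIndex = featuresForNode.map(_(featureIndexIdx)).getOrElse(featureIndexIdx)
      (featureIndexIdx, featureIndex)
    }
    .filter { case (_, featureIndex) => numSplits(featureIndex) != 0 }
}
```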
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
assert(!binAggregates.metadata.isUnordered(featureIndex))
Remove this (If there's any chance of this, then we should find ways to test it.)
val featureValue = categoriesSortedByCentroid(splitIndex)
val leftChildStats =
  binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
This line can be moved outside of the map. Actually, this is the parentCalc, right? So if it's not available, parentCalc can be computed beforehand outside of the map.
Exactly, it's the parentCalc minus the left child stats. Since ImpurityCalculator.subtract() updates the impurity calculator in place, we call binAggregates.getParentImpurityCalculator() to get a copy of the parent impurity calculator, then subtract the left child stats.
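A toy, self-contained illustration of that copy-then-subtract pattern (StatsVector is a stand-in defined here, not Spark's ImpurityCalculator):

```scala
// Stand-in for an impurity calculator whose subtract() mutates the receiver in place.
class StatsVector(val stats: Array[Double]) {
  def copy: StatsVector = new StatsVector(stats.clone())
  def subtract(other: StatsVector): StatsVector = {
    var i = 0
    while (i < stats.length) { stats(i) -= other.stats(i); i += 1 }
    this
  }
}

val parentStats = new StatsVector(Array(10.0, 4.0))    // e.g. label counts at the parent node
val leftChildStats = new StatsVector(Array(6.0, 1.0))
// Copy the parent stats first (subtract mutates the receiver), then remove the left child's
// contribution to obtain the right child's stats: Array(4.0, 3.0).
val rightChildStats = parentStats.copy.subtract(leftChildStats)
```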
// Unordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numSplits = binAggregates.metadata.numSplits(featureIndex)
var parentCalc = parentCalculator
It'd be nice to calculate the parentCalc right away here, if needed. That seems possible just by taking the first candidate split. Then we could simplify calculateImpurityStats by not passing in parentCalc as an option.
val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats)
(featureValue, centroid)
}
// TODO(smurching): How to handle logging statements like these?
What's the issue? You should be able to call logDebug if this object inherits from org.apache.spark.internal.Logging
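For context, a minimal sketch of what that could look like (assuming the object lives inside the org.apache.spark namespace, since Logging is package-private to Spark; the object and method names here are hypothetical):

```scala
package org.apache.spark.ml.tree.impl

import org.apache.spark.internal.Logging

// Mixing in Spark's internal Logging trait provides logDebug / logInfo / logWarning, etc.
private[impl] object LocalTreeLoggingSketch extends Logging {
  def reportSkippedCategories(numSkipped: Int): Unit = {
    logDebug(s"Skipped $numSkipped categories during centroid computation.")
  }
}
```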
    node: LearningNode): (Split, ImpurityStats) = {
  val validFeatureSplits = getNonConstantFeatures(binAggregates.metadata, featuresForNode)
  // For each (feature, split), calculate the gain, and select the best (feature, split).
  val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator)
Note to check: Will node.stats == null for the top level for sure?
I believe so, the nodes at the top level are created (RandomForest.scala:178) with LearningNode.emptyNode, which sets node.stats = null.
I could change this to check node depth (via node index), but if we're planning on deprecating node indices in the future it might be best not to.
@@ -112,7 +113,7 @@ private[spark] object ImpurityStats {
   * minimum number of instances per node.
   */
  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
-   new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+   new ImpurityStats(Double.MinValue, impurity = -1,
Q: Why -1 here?
I changed this to be -1 here since node impurity would eventually get set to -1 anyways when LearningNodes with invalid ImpurityStats were converted into decision tree leaf nodes (see LearningNode.toNode).
…ately reflect what the method actually does). Switch back to view, withFilter in getNonConstantFeatures
…e the map call in chooseUnorderedCategoricalSplit, orderedSplitHelper
Test build #83873 has finished for PR 19433 at commit
Test build #83874 has finished for PR 19433 at commit
Test build #97977 has finished for PR 19433 at commit
Test build #101549 has finished for PR 19433 at commit
Test build #101588 has finished for PR 19433 at commit
Is this still a thing you are actively working on?
Thank you for your contribution! We've used this code extensively as a basis for our @cisco/oraf library, which incorporates local training into the existing decision tree and random forest APIs, and managed to significantly speed up the training process.
That's cool @rstarosta. Does having it in a library meet the needs of folks so that we can close this PR?
Test build #110569 has finished for PR 19433 at commit
We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
What changes were proposed in this pull request?
Overview
This PR adds local tree training for decision tree regressors as a first step for addressing SPARK-3162: train decision trees locally when possible.
See this design doc (in particular the local tree training section) for detailed discussion of the proposed changes.
Distributed training logic has been refactored but only minimally modified; the local tree training implementation leverages existing distributed training logic for computing impurities and splits. This shared logic has been refactored into ...Utils objects (e.g. SplitUtils.scala, ImpurityUtils.scala).
How to Review
Each commit in this PR adds non-overlapping functionality, so the PR can be reviewed commit-by-commit.
Changes introduced by each commit:
1. Local tree training data structures (FeatureVector, TrainingInfo)
2. Utility methods for impurity and split calculations shared by local & distributed training (SplitUtils, ImpurityUtils, AggUpdateUtils), largely copied from existing distributed training code in RandomForest.scala
3. Tests for the split utility methods (TreeSplitUtilsSuite)
4. Update RandomForest.scala to depend on the utility methods introduced in 2.
5. Local decision tree training logic (LocalDecisionTree)
6. Local tree training test suites (LocalTreeUnitSuite, LocalTreeIntegrationSuite)
How was this patch tested?
No existing tests were modified. The following new tests were added (also described above):
* Tests for local tree training data structures and utilities (LocalTreeDataSuite, LocalTreeUtilsSuite)
* Tests for split utility methods (TreeSplitUtilsSuite)
* Unit tests for local tree training (LocalTreeUnitSuite)
* Integration tests for local tree training (LocalTreeIntegrationSuite)