Skip to content

Commit c9a8e01

Browse files
committed
Fix test bug where instanceWeights weren't properly passed to update methods
1 parent cc6a30c commit c9a8e01

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ private[impl] object AggUpdateUtils {
5252
featureIndex: Int,
5353
featureIndexIdx: Int,
5454
splits: Array[Array[Split]],
55-
instanceWeight: Double = 1.0): Unit = {
55+
instanceWeight: Double): Unit = {
5656
val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
5757
// Each unordered split has a corresponding bin for impurity stats of data points that fall
5858
// onto the left side of the split. For each unordered split, update left-side bin if applicable
@@ -75,7 +75,7 @@ private[impl] object AggUpdateUtils {
7575
label: Double,
7676
featureIndex: Int,
7777
featureIndexIdx: Int,
78-
instanceWeight: Double = 1.0): Unit = {
78+
instanceWeight: Double): Unit = {
7979
// The bin index of an ordered feature is just the feature value itself
8080
val binIndex = featureValue
8181
agg.update(featureIndexIdx, binIndex, label, instanceWeight)

mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ private[ml] object LocalDecisionTree {
5151
= LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.datum.binnedFeatures))
5252
val labels = input.map(_.datum.label)
5353

54-
// Train classifier if numClasses is between 1 and 32, otherwise fit a regression model
55-
// on the dataset
54+
// Fit a regression model on the dataset, throwing an error if metadata indicates that
55+
// we should train a classifier.
56+
// TODO: Add support for training classifiers
5657
if (metadata.numClasses > 1 && metadata.numClasses <= 32) {
5758
throw new UnsupportedOperationException("Local training of a decision tree classifier is " +
5859
"unsupported; currently, only regression is supported")
@@ -137,13 +138,13 @@ private[ml] object LocalDecisionTree {
137138
from.until(to).foreach { idx =>
138139
val rowIndex = col.indices(idx)
139140
AggUpdateUtils.updateUnorderedFeature(statsAggregator, col.values(idx), labels(rowIndex),
140-
featureIndex = col.featureIndex, featureIndexIdx, splits)
141+
featureIndex = col.featureIndex, featureIndexIdx, splits, instanceWeight = 1.0)
141142
}
142143
} else {
143144
from.until(to).foreach { idx =>
144145
val rowIndex = col.indices(idx)
145146
AggUpdateUtils.updateOrderedFeature(statsAggregator, col.values(idx), labels(rowIndex),
146-
featureIndex = col.featureIndex, featureIndexIdx)
147+
featureIndex = col.featureIndex, featureIndexIdx, instanceWeight = 1.0)
147148
}
148149
}
149150
}

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,13 @@ private[spark] object RandomForest extends Logging {
281281
if (unorderedFeatures.contains(featureIndex)) {
282282
AggUpdateUtils.updateUnorderedFeature(agg,
283283
featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label,
284-
featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, splits = splits)
284+
featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, splits = splits,
285+
instanceWeight = instanceWeight)
285286
} else {
286287
AggUpdateUtils.updateOrderedFeature(agg,
287288
featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label,
288-
featureIndex = featureIndex, featureIndexIdx = featureIndexIdx)
289+
featureIndex = featureIndex, featureIndexIdx = featureIndexIdx,
290+
instanceWeight = instanceWeight)
289291
}
290292
featureIndexIdx += 1
291293
}

0 commit comments

Comments
 (0)