Skip to content

Commit c8b108c

Browse files
rakeshkashyap123Rakesh Kashyap Hanasoge Padmanabha
andauthored
Skip anchored and derived features (#1052)
* Skip adding the Anchored and derived features in feature join * Add complex test case * Add comments * minor cosmetic issue * Method name change --------- Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <[email protected]>
1 parent daec247 commit c8b108c

File tree

6 files changed

+186
-14
lines changed

6 files changed

+186
-14
lines changed

feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/sources/FeatureGroupsUpdater.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,49 @@ private[offline] class FeatureGroupsUpdater {
110110
(updatedFeatureGroups, updatedKeyTaggedFeatures)
111111
}
112112

113+
/**
114+
* Update the feature groups (for feature join) based on the missing features. Some anchored features can be missing if the feature data
115+
* is not present. Remove those anchored features, and also the corresponding derived features which depend on them.
116+
*
117+
* @param featureGroups
118+
* @param allAnchoredFeaturesWithData names of the anchored features to keep; anchored features not in this list are removed
119+
* @param keyTaggedFeatures
120+
* @return a tuple of the updated feature groups and the key tagged features that remain after the removal
121+
*/
122+
def removeMissingFeatures(featureGroups: FeatureGroups, allAnchoredFeaturesWithData: Seq[String],
123+
keyTaggedFeatures: Seq[JoiningFeatureParams]): (FeatureGroups, Seq[JoiningFeatureParams]) = {
124+
125+
// We need to add the window agg and passthrough features to it, as they are kept alongside the anchored features.
126+
val updatedAnchoredFeatures = featureGroups.allAnchoredFeatures.filter(featureRow =>
127+
allAnchoredFeaturesWithData.contains(featureRow._1)) ++ featureGroups.allWindowAggFeatures ++ featureGroups.allPassthroughFeatures
128+
129+
val updatedSeqJoinFeature = featureGroups.allSeqJoinFeatures.filter(seqJoinFeature => {
130+
// Find the constituent features consumed by every sequential join feature
131+
val allAnchoredFeaturesInDerived = seqJoinFeature._2.consumedFeatureNames.map(_.getFeatureName)
132+
val containsFeature: Seq[Boolean] = allAnchoredFeaturesInDerived.map(feature => updatedAnchoredFeatures.contains(feature))
133+
!containsFeature.contains(false)
134+
})
135+
136+
// Iterate over the derived features and remove the ones which depend on the removed anchored features.
137+
val updatedDerivedFeatures = featureGroups.allDerivedFeatures.filter(derivedFeature => {
138+
// Find the constituent anchored features for every derived feature
139+
val allAnchoredFeaturesInDerived = derivedFeature._2.consumedFeatureNames.map(_.getFeatureName)
140+
val containsFeature: Seq[Boolean] = allAnchoredFeaturesInDerived.map(feature => updatedAnchoredFeatures.contains(feature)
141+
|| featureGroups.allDerivedFeatures.contains(feature))
142+
!containsFeature.contains(false)
143+
}) ++ updatedSeqJoinFeature
144+
145+
log.warn(s"Removed the following features:- ${featureGroups.allAnchoredFeatures.keySet.diff(updatedAnchoredFeatures.keySet)}," +
146+
s"${featureGroups.allDerivedFeatures.keySet.diff(updatedDerivedFeatures.keySet)}," +
147+
s" ${featureGroups.allSeqJoinFeatures.keySet.diff(updatedSeqJoinFeature.keySet)}")
148+
val updatedFeatureGroups = FeatureGroups(updatedAnchoredFeatures, updatedDerivedFeatures, featureGroups.allWindowAggFeatures,
149+
featureGroups.allPassthroughFeatures, updatedSeqJoinFeature)
150+
val updatedKeyTaggedFeatures = keyTaggedFeatures.filter(feature => updatedAnchoredFeatures.contains(feature.featureName)
151+
|| updatedDerivedFeatures.contains(feature.featureName) || featureGroups.allWindowAggFeatures.contains(feature.featureName)
152+
|| featureGroups.allPassthroughFeatures.contains(feature.featureName) || updatedSeqJoinFeature.contains(feature.featureName))
153+
(updatedFeatureGroups, updatedKeyTaggedFeatures)
154+
}
155+
113156
/**
114157
* Exclude anchored and derived features from the join stage if they do not have a valid path.
115158
* @param featureToPathsMap Map of anchored feature names to their paths

feathr-impl/src/main/scala/com/linkedin/feathr/offline/join/DataFrameFeatureJoiner.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ import com.linkedin.feathr.offline
55
import com.linkedin.feathr.offline.client.DataFrameColName
66
import com.linkedin.feathr.offline.client.DataFrameColName.getFeatureAlias
77
import com.linkedin.feathr.offline.config.FeatureJoinConfig
8+
import com.linkedin.feathr.offline.config.sources.FeatureGroupsUpdater
89
import com.linkedin.feathr.offline.derived.DerivedFeatureEvaluator
910
import com.linkedin.feathr.offline.job.FeatureTransformation.transformSingleAnchorDF
10-
import com.linkedin.feathr.offline.job.{FeatureTransformation, TransformedResult}
11+
import com.linkedin.feathr.offline.job.{FeatureTransformation, LocalFeatureJoinJob, TransformedResult}
1112
import com.linkedin.feathr.offline.join.algorithms._
1213
import com.linkedin.feathr.offline.join.util.{FrequentItemEstimatorFactory, FrequentItemEstimatorType}
1314
import com.linkedin.feathr.offline.join.workflow._
14-
import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlan}
15+
import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlan, MultiStageJoinPlanner}
1516
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
1617
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
1718
import com.linkedin.feathr.offline.swa.SlidingWindowAggregationJoiner
@@ -22,6 +23,7 @@ import com.linkedin.feathr.offline.util.FeathrUtils
2223
import com.linkedin.feathr.offline.util.datetime.DateTimeInterval
2324
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
2425
import org.apache.logging.log4j.LogManager
26+
import org.apache.spark.sql.internal.SQLConf
2527
import org.apache.spark.sql.{DataFrame, SparkSession}
2628
import org.apache.spark.util.sketch.BloomFilter
2729

@@ -189,12 +191,21 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
189191
.toIndexedSeq
190192
.map(featureGroups.allAnchoredFeatures),
191193
failOnMissingPartition)
192-
194+
val shouldSkipFeature = (FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.SKIP_MISSING_FEATURE).toBoolean) ||
195+
(ss.sparkContext.isLocal && SQLConf.get.getConf(LocalFeatureJoinJob.SKIP_MISSING_FEATURE))
193196
val updatedSourceAccessorMap = anchorSourceAccessorMap.filter(anchorEntry => anchorEntry._2.isDefined)
194197
.map(anchorEntry => anchorEntry._1 -> anchorEntry._2.get)
195198

199+
val (updatedFeatureGroups, updatedLogicalPlan) = if (shouldSkipFeature) {
200+
val (newFeatureGroups, newKeyTaggedFeatures) = FeatureGroupsUpdater().removeMissingFeatures(featureGroups,
201+
updatedSourceAccessorMap.keySet.flatMap(featureAnchorWithSource => featureAnchorWithSource.featureAnchor.features).toSeq, keyTaggedFeatures)
202+
203+
val newLogicalPlan = MultiStageJoinPlanner().getLogicalPlan(newFeatureGroups, newKeyTaggedFeatures)
204+
(newFeatureGroups, newLogicalPlan)
205+
} else (featureGroups, logicalPlan)
206+
196207
implicit val joinExecutionContext: JoinExecutionContext =
197-
JoinExecutionContext(ss, logicalPlan, featureGroups, bloomFilters, Some(saltedJoinFrequentItemDFs))
208+
JoinExecutionContext(ss, updatedLogicalPlan, updatedFeatureGroups, bloomFilters, Some(saltedJoinFrequentItemDFs))
198209
// 3. Join sliding window aggregation features
199210
val FeatureDataFrame(withWindowAggFeatureDF, inferredSWAFeatureTypes) =
200211
joinSWAFeatures(ss, obsToJoinWithFeatures, joinConfig, featureGroups, failOnMissingPartition, bloomFilters, swaObsTime)

feathr-impl/src/main/scala/com/linkedin/feathr/offline/transformation/AnchorToDataSourceMapper.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ import com.linkedin.feathr.offline.config.location.{DataLocation, PathList, Simp
1010
import com.linkedin.feathr.offline.generation.IncrementalAggContext
1111
import com.linkedin.feathr.offline.job.LocalFeatureJoinJob
1212
import com.linkedin.feathr.offline.source.DataSource
13-
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
14-
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
13+
import com.linkedin.feathr.offline.source.accessor.{DataPathHandler, DataSourceAccessor, NonTimeBasedDataSourceAccessor}
1514
import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler
1615
import com.linkedin.feathr.offline.source.pathutil.{PathChecker, TimeBasedHdfsPathAnalyzer}
1716
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils
@@ -70,12 +69,23 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
7069
}
7170
}
7271
val timeSeriesSource = try {
73-
Some(DataSourceAccessor(ss = ss,
72+
val dataSource = DataSourceAccessor(ss = ss,
7473
source = source,
7574
dateIntervalOpt = dateInterval,
7675
expectDatumType = Some(expectDatumType),
7776
failOnMissingPartition = failOnMissingPartition,
78-
dataPathHandlers = dataPathHandlers))
77+
dataPathHandlers = dataPathHandlers)
78+
79+
// If it is a non-time-based source, the dataframe is loaded at runtime; however, that is too late to decide if the
80+
// feature should be skipped. So, we will try to take the first row here, and see if it succeeds.
81+
if (dataSource.isInstanceOf[NonTimeBasedDataSourceAccessor] && (shouldSkipFeature || (ss.sparkContext.isLocal &&
82+
SQLConf.get.getConf(LocalFeatureJoinJob.SKIP_MISSING_FEATURE)))) {
83+
if (dataSource.get().take(1).isEmpty) None else {
84+
Some(dataSource)
85+
}
86+
} else {
87+
Some(dataSource)
88+
}
7989
} catch {
8090
case e: Exception => if (shouldSkipFeature || (ss.sparkContext.isLocal &&
8191
SQLConf.get.getConf(LocalFeatureJoinJob.SKIP_MISSING_FEATURE))) None else throw e
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
a_id
2-
1
3-
2
4-
3
1+
a_id,timestamp
2+
1,2019-05-20
3+
2,2019-05-19
4+
3,2019-05-19

feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import com.linkedin.feathr.common.configObj.configbuilder.ConfigBuilderException
44
import com.linkedin.feathr.common.exception.FeathrConfigException
55
import com.linkedin.feathr.offline.config.location.SimplePath
66
import com.linkedin.feathr.offline.generation.SparkIOUtils
7-
import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager
7+
import com.linkedin.feathr.offline.job.{LocalFeatureJoinJob, PreprocessedDataFrameManager}
88
import com.linkedin.feathr.offline.source.dataloader.{AvroJsonDataLoader, CsvDataLoader}
99
import com.linkedin.feathr.offline.util.FeathrTestUtils
1010
import org.apache.spark.sql.Row
1111
import org.apache.spark.sql.functions.col
12+
import org.apache.spark.sql.internal.SQLConf
1213
import org.apache.spark.sql.types._
1314
import org.testng.Assert.assertTrue
1415
import org.testng.annotations.{BeforeClass, Test}
@@ -278,6 +279,113 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
278279
// FeathrTestUtils.assertDataFrameApproximatelyEquals(filteredDf, expectedDf, cmpFunc)
279280
}
280281

282+
/*
283+
* Test skipping a combination of anchored, derived and SWA features.
284+
*/
285+
@Test
286+
def testSkipAnchoredFeatures: Unit = {
287+
SQLConf.get.setConf(LocalFeatureJoinJob.SKIP_MISSING_FEATURE, true)
288+
val df = runLocalFeatureJoinForTest(
289+
joinConfigAsString =
290+
"""
291+
|settings: {
292+
| joinTimeSettings: {
293+
| timestampColumn: {
294+
| def: "timestamp"
295+
| format: "yyyy-MM-dd"
296+
| }
297+
| simulateTimeDelay: 1d
298+
| }
299+
|}
300+
|
301+
| features: {
302+
| key: a_id
303+
| featureList: ["featureWithNull", "derived_featureWithNull", "featureWithNull2", "derived_featureWithNull2",
304+
| "aEmbedding", "memberEmbeddingAutoTZ"]
305+
| }
306+
""".stripMargin,
307+
featureDefAsString =
308+
"""
309+
| sources: {
310+
| swaSource: {
311+
| location: { path: "generaion/daily" }
312+
| timePartitionPattern: "yyyy/MM/dd"
313+
| timeWindowParameters: {
314+
| timestampColumn: "timestamp"
315+
| timestampColumnFormat: "yyyy-MM-dd"
316+
| }
317+
| }
318+
| swaSource1: {
319+
| location: { path: "generation/daily" }
320+
| timePartitionPattern: "yyyy/MM/dd"
321+
| timeWindowParameters: {
322+
| timestampColumn: "timestamp"
323+
| timestampColumnFormat: "yyyy-MM-dd"
324+
| }
325+
| }
326+
|}
327+
|
328+
| anchors: {
329+
| anchor1: {
330+
| source: "anchorAndDerivations/nullVaueSource.avro.json"
331+
| key: "toUpperCaseExt(mId)"
332+
| features: {
333+
| featureWithNull: "isPresent(value) ? toNumeric(value) : 0"
334+
| }
335+
| }
336+
| anchor2: {
337+
| source: "anchorAndDerivations/nullValueSource.avro.json"
338+
| key: "toUpperCaseExt(mId)"
339+
| features: {
340+
| featureWithNull2: "isPresent(value) ? toNumeric(value) : 0"
341+
| }
342+
| }
343+
| swaAnchor: {
344+
| source: "swaSource"
345+
| key: "x"
346+
| features: {
347+
| aEmbedding: {
348+
| def: "embedding"
349+
| aggregation: LATEST
350+
| window: 3d
351+
| }
352+
| }
353+
| }
354+
| swaAnchor1: {
355+
| source: "swaSource1"
356+
| key: "x"
357+
| features: {
358+
| memberEmbeddingAutoTZ: {
359+
| def: "embedding"
360+
| aggregation: LATEST
361+
| window: 3d
362+
| type: {
363+
| type: TENSOR
364+
| tensorCategory: SPARSE
365+
| dimensionType: [INT]
366+
| valType: FLOAT
367+
| }
368+
| }
369+
| }
370+
| }
371+
|}
372+
|derivations: {
373+
|
374+
| derived_featureWithNull: "featureWithNull * 2"
375+
| derived_featureWithNull2: "featureWithNull2 * 2"
376+
|}
377+
""".stripMargin,
378+
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv")
379+
380+
assertTrue(!df.data.columns.contains("featureWithNull"))
381+
assertTrue(!df.data.columns.contains("derived_featureWithNull"))
382+
assertTrue(df.data.columns.contains("derived_featureWithNull2"))
383+
assertTrue(df.data.columns.contains("featureWithNull2"))
384+
assertTrue(!df.data.columns.contains("aEmbedding"))
385+
assertTrue(df.data.columns.contains("memberEmbeddingAutoTZ"))
386+
SQLConf.get.setConf(LocalFeatureJoinJob.SKIP_MISSING_FEATURE, false)
387+
}
388+
281389
/*
282390
* Test features with null values.
283391
*/

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
version=0.10.4-rc6
1+
version=0.10.4-rc7
22
SONATYPE_AUTOMATIC_RELEASE=true
33
POM_ARTIFACT_ID=feathr_2.12

0 commit comments

Comments
 (0)