Inject SDK-side flattens while handling input/output coder mismatch in flattens. #34641

Merged: 5 commits, Apr 16, 2025
5 changes: 0 additions & 5 deletions runners/prism/java/build.gradle
@@ -109,11 +109,6 @@ def sickbayTests = [
// ShardedKey not yet implemented.
'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow',

// Java side dying during execution.
// Stream corruption error java side: failed:java.io.StreamCorruptedException: invalid stream header: 206E6F74
// Likely due to prism's coder changes.
'org.apache.beam.sdk.transforms.FlattenTest.testFlattenWithDifferentInputAndOutputCoders2',

// java.lang.IllegalStateException: Output with tag Tag<output> must have a schema in order to call getRowReceiver
// Ultimately because the getRowReceiver code path SDK side isn't friendly to LengthPrefix wrapping of row coders.
// https://github.com/apache/beam/issues/32931
67 changes: 31 additions & 36 deletions sdks/go/pkg/beam/runners/prism/internal/handlerunner.go
@@ -88,8 +88,7 @@ func (h *runner) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipep
}

func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.Components) prepareResult {
if !h.config.SDKFlatten {
t.EnvironmentId = "" // force the flatten to be a runner transform due to configuration.
if !h.config.SDKFlatten && !strings.HasPrefix(tid, "ft_") {
Contributor:
I'll note that there's no user serviceable way to do these configurations at the moment, and it really was a hard binary. It would be acceptable to remove the SDKFlatten option in favour of just a single approach that biases to runner flattens, but does the SDK flatten to get around these issues.
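For context, the guard this comment attaches to tells user-authored flattens apart from the ones this change injects, whose transform ids are given an "ft_" prefix further down. A hypothetical way to read that condition (the helper name is illustrative; the change keeps the inline check):

```go
// Illustrative only: the change keeps the inline strings.HasPrefix check.
// A flatten stays SDK-side either when the runner is configured that way or
// when it is one of the flattens this handler injected (ids prefixed "ft_").
func (h *runner) keepAsSDKFlatten(tid string) bool {
	return h.config.SDKFlatten || strings.HasPrefix(tid, "ft_")
}
```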

forcedRoots := []string{tid} // Have runner side transforms be roots.

// Force runner flatten consumers to be roots.
@@ -109,52 +108,48 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C
// they're written out to the runner in the same fashion.
// This may stop being necessary once Flatten Unzipping happens in the optimizer.
outPCol := comps.GetPcollections()[outColID]
outCoderID := outPCol.CoderId
outCoder := comps.GetCoders()[outCoderID]
coderSubs := map[string]*pipepb.Coder{}
pcollSubs := map[string]*pipepb.PCollection{}
tSubs := map[string]*pipepb.PTransform{}

if !strings.HasPrefix(outCoderID, "cf_") {
// Create a new coder id for the flatten output PCollection and use
// this coder id for all input PCollections
outCoderID = "cf_" + outColID
outCoder = proto.Clone(outCoder).(*pipepb.Coder)
coderSubs[outCoderID] = outCoder

pcollSubs[outColID] = proto.Clone(outPCol).(*pipepb.PCollection)
pcollSubs[outColID].CoderId = outCoderID

outPCol = pcollSubs[outColID]
}

for _, p := range t.GetInputs() {
ts := proto.Clone(t).(*pipepb.PTransform)
ts.EnvironmentId = "" // force the flatten to be a runner transform due to configuration.
for localID, p := range t.GetInputs() {
inPCol := comps.GetPcollections()[p]
if inPCol.CoderId != outPCol.CoderId {
if strings.HasPrefix(inPCol.CoderId, "cf_") {
// The input pcollection is the output of another flatten:
// e.g. [[a, b] | Flatten], c] | Flatten
// In this case, we just point the input coder id to the new flatten
// output coder, so any upstream input pcollections will use the new
// output coder.
coderSubs[inPCol.CoderId] = outCoder
} else {
// Create a substitute PCollection for this input with the flatten
// output coder id
pcollSubs[p] = proto.Clone(inPCol).(*pipepb.PCollection)
pcollSubs[p].CoderId = outPCol.CoderId
}
// TODO: do the following injection conditionally.
// Now we inject an SDK-side flatten between the upstream transform and
// the flatten.
// Before: upstream -> [upstream out] -> runner flatten
// After: upstream -> [upstream out] -> SDK-side flatten -> [SDK-side flatten out] -> runner flatten
// Create a PCollection sub
fColID := "fc_" + p + "_to_" + outColID
fPCol := proto.Clone(outPCol).(*pipepb.PCollection)
fPCol.CoderId = outPCol.CoderId // same coder as runner flatten
pcollSubs[fColID] = fPCol

// Create a PTransform sub
ftID := "ft_" + p + "_to_" + outColID
ft := proto.Clone(t).(*pipepb.PTransform)
ft.EnvironmentId = t.EnvironmentId // Set environment to ensure it is a SDK-side transform
ft.Inputs = map[string]string{"0": p}
ft.Outputs = map[string]string{"0": fColID}
tSubs[ftID] = ft

// Replace the input of runner flatten with the output of SDK-side flatten
ts.Inputs[localID] = fColID

// Force sdk-side flattens to be roots
forcedRoots = append(forcedRoots, ftID)
}
}
tSubs[tid] = ts

// Return the new components which is the transforms consumer
return prepareResult{
// We sub this flatten with itself, to not drop it.
SubbedComps: &pipepb.Components{
Transforms: map[string]*pipepb.PTransform{
tid: t,
},
Transforms: tSubs,
Pcollections: pcollSubs,
Coders: coderSubs,
},
RemovedLeaves: nil,
ForcedRoots: forcedRoots,
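In short, for every input whose coder differs from the flatten's output coder, the loop above injects an SDK-environment flatten that re-encodes that input into a new intermediate PCollection carrying the output coder, then points the runner flatten at that intermediate. A condensed sketch of the per-input step, using the same pipeline protos but an illustrative helper signature (the change inlines this logic in the loop rather than factoring it out):

```go
package internal

import (
	"google.golang.org/protobuf/proto"

	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
)

// injectSDKFlatten sketches the rewiring handleFlatten now performs for one
// mismatched input: upstream -> [in] -> SDK flatten -> [fc_...] -> runner flatten.
func injectSDKFlatten(localID, inColID, outColID string,
	runnerFlatten *pipepb.PTransform, // clone of the flatten with EnvironmentId cleared
	sdkTemplate *pipepb.PTransform, // clone of the flatten keeping its SDK environment
	outPCol *pipepb.PCollection,
	pcollSubs map[string]*pipepb.PCollection,
	tSubs map[string]*pipepb.PTransform,
	forcedRoots []string) []string {

	// Intermediate PCollection that already carries the flatten's output coder,
	// so the runner flatten only ever sees matching coders on its inputs.
	fColID := "fc_" + inColID + "_to_" + outColID
	pcollSubs[fColID] = proto.Clone(outPCol).(*pipepb.PCollection)

	// SDK-side flatten from the original input to the intermediate; keeping the
	// SDK environment id is what makes the SDK harness do the re-encoding.
	ftID := "ft_" + inColID + "_to_" + outColID
	ft := proto.Clone(sdkTemplate).(*pipepb.PTransform)
	ft.Inputs = map[string]string{"0": inColID}
	ft.Outputs = map[string]string{"0": fColID}
	tSubs[ftID] = ft

	// The runner flatten now consumes the re-encoded intermediate, and the
	// injected flatten is forced to be a stage root.
	runnerFlatten.Inputs[localID] = fColID
	return append(forcedRoots, ftID)
}
```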
3 changes: 3 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/preprocess.go
@@ -492,6 +492,9 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
}

stg.internalCols = internal
// Sort the keys of internal producers (from stageFacts.PcolProducers)
// to ensure deterministic order for stable tests.
sort.Strings(stg.internalCols)
Contributor:
Good find! I thought I had everything deterministic already.

stg.outputs = maps.Values(outputs)
stg.sideInputs = sideInputs

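The one-line sort above matters because Go randomizes map iteration order, so keys collected from the producer map would otherwise come back in a different order on every run. A minimal standalone sketch of the pattern (names are illustrative, not from the change):

```go
package main

import (
	"fmt"
	"sort"
)

func main() {
	// Stand-in for the stage's internal producer map keyed by PCollection id.
	producers := map[string]string{"pcol_b": "t1", "pcol_a": "t0", "pcol_c": "t2"}

	// Collect the keys; ranging over a map yields them in randomized order.
	internalCols := make([]string, 0, len(producers))
	for col := range producers {
		internalCols = append(internalCols, col)
	}

	// Sorting makes the ordering deterministic, so stage layout and tests are stable.
	sort.Strings(internalCols)
	fmt.Println(internalCols) // always prints [pcol_a pcol_b pcol_c]
}
```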
@@ -299,9 +299,16 @@ def test_sql(self):

def test_flattened_side_input(self):
# Blocked on support for transcoding
# https://jira.apache.org/jira/browse/BEAM-6523
# https://github.com/apache/beam/issues/19365
super().test_flattened_side_input(with_transcoding=False)

def test_flatten_and_gbk(self):
# Blocked on support for transcoding
# https://github.com/apache/beam/issues/19365
# Also blocked on support of flatten and groupby sharing the same input
# https://github.com/apache/beam/issues/34647
raise unittest.SkipTest("https://github.com/apache/beam/issues/34647")

def test_metrics(self):
super().test_metrics(check_gauge=False, check_bounded_trie=False)

@@ -538,6 +538,22 @@ def test_flattened_side_input(self, with_transcoding=True):
equal_to([('a', 1), ('b', 2)] + third_element),
label='CheckFlattenOfSideInput')

def test_flatten_and_gbk(self, with_transcoding=True):
with self.create_pipeline() as p:
side1 = p | 'side1' >> beam.Create([('a', 1)])
if with_transcoding:
# Also test non-matching coder types (transcoding required)
second_element = [('another_type')]
else:
second_element = [('b', 2)]
side2 = p | 'side2' >> beam.Create(second_element)

flatten_out = (side1, side2) | beam.Flatten()
gbk_out = side1 | beam.GroupByKey()

assert_that(flatten_out, equal_to([('a', 1)] + second_element))
assert_that(gbk_out, equal_to([('a', [1])]))

def test_gbk_side_input(self):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create([None])
@@ -142,6 +142,11 @@ def test_flattened_side_input(self):
# https://github.com/apache/beam/issues/20984
super().test_flattened_side_input(with_transcoding=False)

def test_flatten_and_gbk(self):
# Blocked on support for transcoding
# https://github.com/apache/beam/issues/20984
super().test_flatten_and_gbk(with_transcoding=False)

def test_pack_combiners(self):
# Stages produced by translations.pack_combiners are fused
# by translations.greedily_fuse, which prevent the stages
@@ -174,9 +174,14 @@ def test_pardo_dynamic_timer(self):

def test_flattened_side_input(self):
# Blocked on support for transcoding
# https://jira.apache.org/jira/browse/BEAM-7236
# https://github.com/apache/beam/issues/19504
super().test_flattened_side_input(with_transcoding=False)

def test_flatten_and_gbk(self):
# Blocked on support for transcoding
# https://github.com/apache/beam/issues/19504
super().test_flatten_and_gbk(with_transcoding=False)

def test_custom_merging_window(self):
raise unittest.SkipTest("https://github.com/apache/beam/issues/20641")
