
Commit c1650e9

Normalize tf record io (#34411)
* create TFRecordReadSchemaTransform
* create TFRecordWriteSchemaTransform
* create TFRecordSchemaTransform Test
* fix conflicts
* add writeToTFRecord return
* add translation file and getRowConfiguration method for said file
* fix formatting
* remove print
* fix more lint issues
* add support for bytes in yaml
* update ReadTransform with error handling and use string compression
* add error handling to writetransform and change compression to string
* change compression to string
* update compression type and no_spilling
* update writeToTFRecord parameters
* add tfrecord yaml test pipeline
* remove old code
* fix lint issue
* update standard external transforms with tfrecord info
* remove bad character and broken comment
* fix lint issue
* add no_spilling doc string
* change tfrecord.yaml to write version
* update parameter name
* update read and write yaml for tfrecord
* update pipeline to handle write and read
* fix lint issues
* fix lint
* fix lint, precommit checker is broken for me, so hence many commits :)
* minor fix
* fix java comments
* Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java
  Co-authored-by: Ahmed Abualsaud <[email protected]>
* Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordSchemaTransformTranslation.java
  Co-authored-by: Ahmed Abualsaud <[email protected]>
* fix class name change
* fix MakeItWork test case
* fix spotless issues
* Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java
  Co-authored-by: Ahmed Abualsaud <[email protected]>
* remove nullable for outputprefix per comments
* remove no_spilling parameter and extra white space
* update standard external transforms
* remove one more no_spilling
* fix nullable on no_spilling
* rerun generate external transforms
* fix order of python and java providers
* revert java and python providers changes
* update test and fix prior test failures

---------

Co-authored-by: Ahmed Abualsaud <[email protected]>
1 parent 3af0398

12 files changed: +1575 −2 lines
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformConfiguration.java (new file)
@@ -0,0 +1,104 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;

/**
 * Configuration for reading from TFRecord.
 *
 * <p>This class is meant to be used with {@link TFRecordReadSchemaTransformProvider}.
 *
 * <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
 * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
 * repository.
 */
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class TFRecordReadSchemaTransformConfiguration implements Serializable {

  public void validate() {
    String invalidConfigMessage = "Invalid TFRecord Read configuration: ";

    if (getValidate()) {
      String filePattern = getFilePattern();
      try {
        MatchResult matches = FileSystems.match(filePattern);
        checkState(
            !matches.metadata().isEmpty(), "Unable to find any files matching %s", filePattern);
      } catch (IOException e) {
        throw new IllegalStateException(
            String.format(invalidConfigMessage + "Failed to validate %s", filePattern), e);
      }
    }

    ErrorHandling errorHandling = getErrorHandling();
    if (errorHandling != null) {
      checkArgument(
          !Strings.isNullOrEmpty(errorHandling.getOutput()),
          invalidConfigMessage + "Output must not be empty if error handling specified.");
    }
  }

  /** Instantiates a {@link TFRecordReadSchemaTransformConfiguration.Builder} instance. */
  public static TFRecordReadSchemaTransformConfiguration.Builder builder() {
    return new AutoValue_TFRecordReadSchemaTransformConfiguration.Builder();
  }

  @SchemaFieldDescription("Validate file pattern.")
  public abstract boolean getValidate();

  @SchemaFieldDescription("Decompression type to use when reading input files.")
  public abstract String getCompression();

  @SchemaFieldDescription("Filename or file pattern used to find input files.")
  public abstract String getFilePattern();

  @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
  public abstract @Nullable ErrorHandling getErrorHandling();

  abstract Builder toBuilder();

  /** Builder for {@link TFRecordReadSchemaTransformConfiguration}. */
  @AutoValue.Builder
  public abstract static class Builder {

    public abstract Builder setValidate(boolean value);

    public abstract Builder setCompression(String value);

    public abstract Builder setFilePattern(String value);

    public abstract Builder setErrorHandling(@Nullable ErrorHandling errorHandling);

    /** Builds the {@link TFRecordReadSchemaTransformConfiguration} configuration. */
    public abstract TFRecordReadSchemaTransformConfiguration build();
  }
}
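
For illustration, a minimal sketch (not part of this commit) of how a caller might build and validate this configuration. The file pattern and compression values below are placeholders, not values from the change:

// Hypothetical values; builder() and validate() come from the class above.
TFRecordReadSchemaTransformConfiguration config =
    TFRecordReadSchemaTransformConfiguration.builder()
        .setFilePattern("gs://example-bucket/data/*.tfrecord") // placeholder pattern
        .setCompression("AUTO") // must name a Compression enum constant, e.g. AUTO or GZIP
        .setValidate(true)
        .build();
// Throws IllegalStateException if no files match the pattern (or matching fails).
config.validate();

Note that setErrorHandling is the only optional field; the other three setters back non-nullable getters, so build() fails if any of them is unset.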
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformProvider.java (new file)
@@ -0,0 +1,201 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class TFRecordReadSchemaTransformProvider
    extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
  private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
  private static final String OUTPUT = "output";
  private static final String ERROR = "errors";
  public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
  public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
  private static final Logger LOG =
      LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

  /** Returns the expected {@link SchemaTransform} of the configuration. */
  @Override
  protected SchemaTransform from(TFRecordReadSchemaTransformConfiguration configuration) {
    return new TFRecordReadSchemaTransform(configuration);
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */
  @Override
  public String identifier() {
    return IDENTIFIER;
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. */
  @Override
  public List<String> outputCollectionNames() {
    return Arrays.asList(OUTPUT, ERROR);
  }

  /**
   * An implementation of {@link SchemaTransform} for TFRecord read jobs configured using {@link
   * TFRecordReadSchemaTransformConfiguration}.
   */
  static class TFRecordReadSchemaTransform extends SchemaTransform {
    private final TFRecordReadSchemaTransformConfiguration configuration;

    TFRecordReadSchemaTransform(TFRecordReadSchemaTransformConfiguration configuration) {
      this.configuration = configuration;
    }

    public Row getConfigurationRow() {
      try {
        // To stay consistent with our SchemaTransform configuration naming conventions,
        // we sort lexicographically
        return SchemaRegistry.createDefault()
            .getToRowFunction(TFRecordReadSchemaTransformConfiguration.class)
            .apply(configuration)
            .sorted()
            .toSnakeCase();
      } catch (NoSuchSchemaException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public PCollectionRowTuple expand(PCollectionRowTuple input) {
      // Validate configuration parameters
      configuration.validate();

      TFRecordIO.Read readTransform =
          TFRecordIO.read().withCompression(Compression.valueOf(configuration.getCompression()));

      String filePattern = configuration.getFilePattern();
      if (filePattern != null) {
        readTransform = readTransform.from(filePattern);
      }
      if (!configuration.getValidate()) {
        readTransform = readTransform.withoutValidation();
      }

      // Read TFRecord files into a PCollection of byte arrays.
      PCollection<byte[]> tfRecordValues = input.getPipeline().apply(readTransform);

      // Define the schema for the row
      final Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
      Schema errorSchema = ErrorHandling.errorSchemaBytes();
      boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

      SerializableFunction<byte[], Row> bytesToRowFn = getBytesToRowFn(schema);

      // Apply bytes to row fn
      PCollectionTuple outputTuple =
          tfRecordValues.apply(
              ParDo.of(
                      new ErrorFn(
                          "TFRecord-read-error-counter", bytesToRowFn, errorSchema, handleErrors))
                  .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

      PCollectionRowTuple outputRows =
          PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(schema));

      // Error handling
      PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
      if (handleErrors) {
        outputRows =
            outputRows.and(
                checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
      }
      return outputRows;
    }
  }

  public static SerializableFunction<byte[], Row> getBytesToRowFn(Schema schema) {
    return new SimpleFunction<byte[], Row>() {
      @Override
      public Row apply(byte[] input) {
        return Row.withSchema(schema).addValues(input).build();
      }
    };
  }

  public static class ErrorFn extends DoFn<byte[], Row> {
    private final SerializableFunction<byte[], Row> valueMapper;
    private final Counter errorCounter;
    private Long errorsInBundle = 0L;
    private final boolean handleErrors;
    private final Schema errorSchema;

    public ErrorFn(
        String name,
        SerializableFunction<byte[], Row> valueMapper,
        Schema errorSchema,
        boolean handleErrors) {
      this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
      this.valueMapper = valueMapper;
      this.handleErrors = handleErrors;
      this.errorSchema = errorSchema;
    }

    @ProcessElement
    public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
      Row mappedRow = null;
      try {
        mappedRow = valueMapper.apply(msg);
      } catch (Exception e) {
        if (!handleErrors) {
          throw new RuntimeException(e);
        }
        errorsInBundle += 1;
        LOG.warn("Error while parsing the element", e);
        receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
      }
      if (mappedRow != null) {
        receiver.get(OUTPUT_TAG).output(mappedRow);
      }
    }

    @FinishBundle
    public void finish(FinishBundleContext c) {
      errorCounter.inc(errorsInBundle);
      errorsInBundle = 0L;
    }
  }
}
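
A hedged usage sketch (not part of this commit) of the provider above. It assumes the snippet lives in the same org.apache.beam.sdk.io package, since the typed from() override is protected and TFRecordReadSchemaTransform is package-private, and that "config" is a TFRecordReadSchemaTransformConfiguration as sketched earlier:

// Expand the read transform in a pipeline; imports for Pipeline, SchemaTransform,
// PCollection, PCollectionRowTuple, and Row are elided.
Pipeline pipeline = Pipeline.create();
SchemaTransform read = new TFRecordReadSchemaTransformProvider().from(config);
// The read takes no input collections, so start from an empty row tuple.
PCollectionRowTuple result = PCollectionRowTuple.empty(pipeline).apply(read);
// Each row carries the single BYTES field "record" defined in expand(); if error
// handling was configured, the tuple also contains the named error collection.
PCollection<Row> records = result.get("output");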
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordSchemaTransformTranslation.java (new file)
@@ -0,0 +1,79 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.io.TFRecordReadSchemaTransformProvider.TFRecordReadSchemaTransform;
import static org.apache.beam.sdk.io.TFRecordWriteSchemaTransformProvider.TFRecordWriteSchemaTransform;

import com.google.auto.service.AutoService;
import java.util.Map;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

public class TFRecordSchemaTransformTranslation {
  public static class TFRecordReadSchemaTransformTranslator
      extends SchemaTransformTranslation.SchemaTransformPayloadTranslator<
          TFRecordReadSchemaTransform> {
    @Override
    public SchemaTransformProvider provider() {
      return new TFRecordReadSchemaTransformProvider();
    }

    @Override
    public Row toConfigRow(TFRecordReadSchemaTransform transform) {
      return transform.getConfigurationRow();
    }
  }

  public static class TFRecordWriteSchemaTransformTranslator
      extends SchemaTransformTranslation.SchemaTransformPayloadTranslator<
          TFRecordWriteSchemaTransform> {
    @Override
    public SchemaTransformProvider provider() {
      return new TFRecordWriteSchemaTransformProvider();
    }

    @Override
    public Row toConfigRow(TFRecordWriteSchemaTransform transform) {
      return transform.getConfigurationRow();
    }
  }

  @AutoService(TransformPayloadTranslatorRegistrar.class)
  public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar {
    @Override
    @SuppressWarnings({
      "rawtypes",
    })
    public Map<
            ? extends Class<? extends PTransform>,
            ? extends PTransformTranslation.TransformPayloadTranslator>
        getTransformPayloadTranslators() {
      return ImmutableMap
          .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
          .put(TFRecordReadSchemaTransform.class, new TFRecordReadSchemaTransformTranslator())
          .put(TFRecordWriteSchemaTransform.class, new TFRecordWriteSchemaTransformTranslator())
          .build();
    }
  }
}
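
A hedged sketch (not part of this commit) of the round-trip these translators enable: turning an instantiated transform back into its configuration Row, the way a runner or upgrade tool would during pipeline translation. It again assumes same-package access and the "config" object from the first sketch:

// Instantiate the read transform, then recover its config Row via the translator
// registered above.
TFRecordReadSchemaTransform read =
    (TFRecordReadSchemaTransform) new TFRecordReadSchemaTransformProvider().from(config);
Row configRow =
    new TFRecordSchemaTransformTranslation.TFRecordReadSchemaTransformTranslator()
        .toConfigRow(read);
// configRow fields are snake_cased and lexicographically sorted, matching
// getConfigurationRow() in the provider.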
