
Commit a994291

Add explicit schema support to JdbcIO read and xlang transform. (#34128)
* Add ability to set schema in JdbcIO.java and jdbc.py Read.
* Add tests.
* Run postcommit.
* Revert "Run postcommit." This reverts commit daeacfc.
1 parent 5ec0407 · commit a994291
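For orientation, here is a minimal sketch of what the new option looks like from a pipeline author's point of view. The driver, URL, query, and field names are illustrative placeholders, not taken from this commit; only readRows().withSchema(...) itself is what the change adds. Supplying the schema up front spares JdbcIO the extra connection it would otherwise open during pipeline construction to infer one, per the expand() change below.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class ReadWithExplicitSchemaSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Hypothetical schema matching the query's projection; nullability mirrors
    // what JDBC metadata would typically report.
    Schema schema =
        Schema.of(
            Schema.Field.of("name", Schema.FieldType.STRING).withNullable(true),
            Schema.Field.of("id", Schema.FieldType.INT64).withNullable(true));

    PCollection<Row> rows =
        pipeline.apply(
            JdbcIO.readRows()
                .withDataSourceConfiguration(
                    // Placeholder driver and URL, for illustration only.
                    JdbcIO.DataSourceConfiguration.create(
                        "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/mydb"))
                .withQuery("SELECT name, id FROM my_table")
                // New in this commit: skip schema inference at construction time.
                .withSchema(schema));

    pipeline.run().waitUntilFinish();
  }
}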

File tree: 6 files changed, +510 −105 lines


sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java (+31 −3)
@@ -747,6 +747,9 @@ public abstract static class ReadRows extends PTransform<PBegin, PCollection<Row>>
     @Pure
     abstract boolean getDisableAutoCommit();
 
+    @Pure
+    abstract @Nullable Schema getSchema();
+
     abstract Builder toBuilder();
 
     @AutoValue.Builder
@@ -764,6 +767,8 @@ abstract Builder setDataSourceProviderFn(
 
       abstract Builder setDisableAutoCommit(boolean disableAutoCommit);
 
+      abstract Builder setSchema(@Nullable Schema schema);
+
       abstract ReadRows build();
     }
 
@@ -791,6 +796,10 @@ public ReadRows withStatementPreparator(StatementPreparator statementPreparator)
       return toBuilder().setStatementPreparator(statementPreparator).build();
     }
 
+    public ReadRows withSchema(Schema schema) {
+      return toBuilder().setSchema(schema).build();
+    }
+
     /**
      * This method is used to set the size of the data that is going to be fetched and loaded in
      * memory per every database call. Please refer to: {@link java.sql.Statement#setFetchSize(int)}
@@ -832,7 +841,14 @@ public PCollection<Row> expand(PBegin input) {
           getDataSourceProviderFn(),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
 
-      Schema schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
+      // Don't infer schema if explicitly provided.
+      Schema schema;
+      if (getSchema() != null) {
+        schema = getSchema();
+      } else {
+        schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
+      }
+
       PCollection<Row> rows =
           input.apply(
               JdbcIO.<Row>read()
@@ -1294,6 +1310,9 @@ public abstract static class ReadWithPartitions<T, PartitionColumnT>
     @Pure
     abstract boolean getUseBeamSchema();
 
+    @Pure
+    abstract @Nullable Schema getSchema();
+
     @Pure
     abstract @Nullable PartitionColumnT getLowerBound();
 
@@ -1335,6 +1354,8 @@ abstract Builder<T, PartitionColumnT> setDataSourceProviderFn(
 
       abstract Builder<T, PartitionColumnT> setUseBeamSchema(boolean useBeamSchema);
 
+      abstract Builder setSchema(@Nullable Schema schema);
+
       abstract Builder<T, PartitionColumnT> setFetchSize(int fetchSize);
 
       abstract Builder<T, PartitionColumnT> setTable(String tableName);
@@ -1426,6 +1447,10 @@ public ReadWithPartitions<T, PartitionColumnT> withTable(String tableName) {
       return toBuilder().setTable(tableName).build();
     }
 
+    public ReadWithPartitions<T, PartitionColumnT> withSchema(Schema schema) {
+      return toBuilder().setSchema(schema).build();
+    }
+
     private static final int EQUAL = 0;
 
     @Override
@@ -1534,8 +1559,11 @@ public KV<Long, KV<PartitionColumnT, PartitionColumnT>> apply(
       Schema schema = null;
       if (getUseBeamSchema()) {
         schema =
-            ReadRows.inferBeamSchema(
-                dataSourceProviderFn.apply(null), String.format("SELECT * FROM %s", getTable()));
+            getSchema() != null
+                ? getSchema()
+                : ReadRows.inferBeamSchema(
+                    dataSourceProviderFn.apply(null),
+                    String.format("SELECT * FROM %s", getTable()));
         rowMapper = (RowMapper<T>) SchemaUtil.BeamRowMapper.of(schema);
       } else {
         rowMapper = getRowMapper();
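One nuance from the ReadWithPartitions hunks above: getSchema() is consulted only inside the getUseBeamSchema() branch, so withSchema(...) takes effect only together with withRowOutput(). A hedged sketch, reusing the hypothetical connection details and field names from the earlier example:

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class PartitionedReadWithExplicitSchemaSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Hypothetical schema matching the partitioned table's columns.
    Schema schema =
        Schema.of(
            Schema.Field.of("name", Schema.FieldType.STRING).withNullable(true),
            Schema.Field.of("id", Schema.FieldType.INT64).withNullable(true));

    PCollection<Row> rows =
        pipeline.apply(
            JdbcIO.<Row>readWithPartitions()
                .withDataSourceConfiguration(
                    // Placeholder connection details, as in the previous sketch.
                    JdbcIO.DataSourceConfiguration.create(
                        "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/mydb"))
                .withTable("my_table")
                .withPartitionColumn("id")
                .withLowerBound(0L)
                .withUpperBound(1000L)
                // withRowOutput() sets useBeamSchema, the only branch in which
                // the explicit schema is consulted per the diff above.
                .withRowOutput()
                // Skips the "SELECT * FROM my_table" inference query.
                .withSchema(schema));

    pipeline.run().waitUntilFinish();
  }
}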

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java (+11 −2)
@@ -84,7 +84,7 @@ public Schema configurationSchema() {
    */
   @Override
   public JdbcSchemaIO from(String location, Row configuration, @Nullable Schema dataSchema) {
-    return new JdbcSchemaIO(location, configuration);
+    return new JdbcSchemaIO(location, configuration, dataSchema);
   }
 
   @Override
@@ -101,10 +101,12 @@ public PCollection.IsBounded isBounded() {
   static class JdbcSchemaIO implements SchemaIO, Serializable {
     protected final Row config;
     protected final String location;
+    protected final @Nullable Schema dataSchema;
 
-    JdbcSchemaIO(String location, Row config) {
+    JdbcSchemaIO(String location, Row config, @Nullable Schema dataSchema) {
       this.config = config;
       this.location = location;
+      this.dataSchema = dataSchema;
     }
 
     @Override
@@ -147,6 +149,10 @@ public PCollection<Row> expand(PBegin input) {
             readRows = readRows.withDisableAutoCommit(disableAutoCommit);
           }
 
+          if (dataSchema != null) {
+            readRows = readRows.withSchema(dataSchema);
+          }
+
           return input.apply(readRows);
         } else {
 
@@ -175,6 +181,9 @@ public PCollection<Row> expand(PBegin input) {
            readRows = readRows.withDisableAutoCommit(disableAutoCommit);
          }
 
+         if (dataSchema != null) {
+           readRows = readRows.withSchema(dataSchema);
+         }
          return input.apply(readRows);
        }
      }

sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java (+71)
@@ -470,6 +470,51 @@ public void testReadWithSchema() {
     pipeline.run();
   }
 
+  @Test
+  public void testReadRowsWithExplicitSchema() {
+    Schema customSchema =
+        Schema.of(
+            Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
+            Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT64).withNullable(true));
+
+    PCollection<Row> rows =
+        pipeline.apply(
+            JdbcIO.readRows()
+                .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                .withQuery(String.format("select name,id from %s where name = ?", READ_TABLE_NAME))
+                .withStatementPreparator(
+                    preparedStatement -> preparedStatement.setString(1, TestRow.getNameForSeed(1)))
+                .withSchema(customSchema));
+
+    assertEquals(customSchema, rows.getSchema());
+
+    PCollection<Row> output = rows.apply(Select.fieldNames("CUSTOMER_NAME", "CUSTOMER_ID"));
+    PAssert.that(output)
+        .containsInAnyOrder(
+            ImmutableList.of(Row.withSchema(customSchema).addValues("Testval1", 1L).build()));
+
+    pipeline.run();
+  }
+
+  @Test
+  @SuppressWarnings({"UnusedVariable"})
+  public void testIncompatibleSchemaThrowsError() {
+    Schema incompatibleSchema =
+        Schema.of(
+            Schema.Field.of("WRONG_TYPE_NAME", Schema.FieldType.INT64),
+            Schema.Field.of("WRONG_TYPE_ID", Schema.FieldType.STRING));
+
+    Pipeline pipeline = Pipeline.create();
+    pipeline.apply(
+        JdbcIO.readRows()
+            .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+            .withQuery(String.format("select name,id from %s limit 10", READ_TABLE_NAME))
+            .withSchema(incompatibleSchema));
+
+    PipelineExecutionException exception =
+        assertThrows(PipelineExecutionException.class, () -> pipeline.run().waitUntilFinish());
+  }
+
   @Test
   public void testReadWithPartitions() {
     PCollection<TestRow> rows =
@@ -486,6 +531,32 @@ public void testReadWithPartitions() {
     pipeline.run();
   }
 
+  @Test
+  public void testReadWithPartitionsWithExplicitSchema() {
+    Schema customSchema =
+        Schema.of(
+            Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
+            Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));
+
+    PCollection<Row> rows =
+        pipeline.apply(
+            JdbcIO.<Row>readWithPartitions()
+                .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                .withTable(String.format("(select name,id from %s) as subq", READ_TABLE_NAME))
+                .withNumPartitions(5)
+                .withPartitionColumn("id")
+                .withLowerBound(0L)
+                .withUpperBound(1000L)
+                .withRowOutput()
+                .withSchema(customSchema));
+
+    assertEquals(customSchema, rows.getSchema());
+
+    PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1000L);
+
+    pipeline.run();
+  }
+
   @Test
   public void testReadWithPartitionsBySubqery() {
     PCollection<TestRow> rows =

sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java (+87)
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.jdbc;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -85,6 +86,92 @@ public void testPartitionedRead() {
     pipeline.run();
   }
 
+  @Test
+  public void testPartitionedReadWithExplicitSchema() {
+    JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();
+
+    Schema customSchema =
+        Schema.of(
+            Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
+            Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));
+
+    Row config =
+        Row.withSchema(provider.configurationSchema())
+            .withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
+            .withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get())
+            .withFieldValue("username", "")
+            .withFieldValue("password", "")
+            .withFieldValue("partitionColumn", "id")
+            .withFieldValue("partitions", (short) 10)
+            .build();
+
+    JdbcSchemaIOProvider.JdbcSchemaIO schemaIO =
+        provider.from(
+            String.format("(select name,id from %s) as subq", READ_TABLE_NAME),
+            config,
+            customSchema);
+
+    PCollection<Row> output = pipeline.apply(schemaIO.buildReader());
+
+    assertEquals(customSchema, output.getSchema());
+
+    Long expected = Long.valueOf(EXPECTED_ROW_COUNT);
+    PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected);
+
+    PAssert.that(output)
+        .satisfies(
+            rows -> {
+              for (Row row : rows) {
+                assertNotNull(row.getString("CUSTOMER_NAME"));
+                assertNotNull(row.getInt32("CUSTOMER_ID"));
+              }
+              return null;
+            });
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testReadWithExplicitSchema() {
+    JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();
+
+    Schema customSchema =
+        Schema.of(
+            Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
+            Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));
+
+    Row config =
+        Row.withSchema(provider.configurationSchema())
+            .withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
+            .withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get())
+            .withFieldValue("username", "")
+            .withFieldValue("password", "")
+            .withFieldValue("readQuery", "SELECT name, id FROM " + READ_TABLE_NAME)
+            .build();
+
+    JdbcSchemaIOProvider.JdbcSchemaIO schemaIO =
+        provider.from(READ_TABLE_NAME, config, customSchema);
+
+    PCollection<Row> output = pipeline.apply(schemaIO.buildReader());
+
+    assertEquals(customSchema, output.getSchema());
+
+    Long expected = Long.valueOf(EXPECTED_ROW_COUNT);
+    PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected);
+
+    PAssert.that(output)
+        .satisfies(
+            rows -> {
+              for (Row row : rows) {
+                assertNotNull(row.getString("CUSTOMER_NAME"));
+                assertNotNull(row.getInt32("CUSTOMER_ID"));
+              }
+              return null;
+            });
+
+    pipeline.run();
+  }
+
   // This test shouldn't work because we only support numeric and datetime columns and we are trying
   // to use a string column as our partition source.
   @Test
