Skip to content

Commit 1d273c2

Browse files
robertwbhjtran
authored and committed
[YAML] Basic integration testing framework. (apache#29113)
The tests themselves are defined in yaml as a series of pipelines, possibly with some setup code. One of the key features of this framework is that if multiple providers vend the same transform each will be tested to ensure they have consistent behavior.
1 parent 0a52a25 commit 1d273c2

File tree

6 files changed

+358
-2
lines changed

6 files changed

+358
-2
lines changed

sdks/python/apache_beam/testing/util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,12 @@ def expand(self, pcoll):
301301
if not use_global_window:
302302
plain_actual = plain_actual | 'AddWindow' >> ParDo(AddWindow())
303303

304-
plain_actual = plain_actual | 'Match' >> Map(matcher)
304+
return plain_actual | 'Match' >> Map(matcher)
305305

306306
def default_label(self):
307307
return label
308308

309-
actual | AssertThat() # pylint: disable=expression-not-assigned
309+
return actual | AssertThat()
310310

311311

312312
@ptransform_fn
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""Runs integration tests in the tests directory."""
19+
20+
import contextlib
21+
import copy
22+
import glob
23+
import itertools
24+
import logging
25+
import os
26+
import unittest
27+
import uuid
28+
29+
import mock
30+
import yaml
31+
32+
import apache_beam as beam
33+
from apache_beam.io import filesystems
34+
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
35+
from apache_beam.io.gcp.internal.clients import bigquery
36+
from apache_beam.options.pipeline_options import PipelineOptions
37+
from apache_beam.utils import python_callable
38+
from apache_beam.yaml import yaml_provider
39+
from apache_beam.yaml import yaml_transform
40+
41+
42+
@contextlib.contextmanager
def gcs_temp_dir(bucket):
  """Context manager yielding a unique temporary GCS directory under bucket.

  Args:
    bucket: a gs:// bucket (or bucket subpath) to create the directory in.

  Yields:
    The full path of a freshly named temporary directory.
  """
  gcs_tempdir = bucket + '/yaml-' + str(uuid.uuid4())
  try:
    yield gcs_tempdir
  finally:
    # Fix: the delete was previously outside any try/finally, so an exception
    # raised inside the with-body leaked the temporary GCS objects.
    filesystems.FileSystems.delete([gcs_tempdir])
47+
48+
49+
@contextlib.contextmanager
def temp_bigquery_table(project, prefix='yaml_bq_it_'):
  """Context manager yielding a table spec in a fresh, temporary BQ dataset.

  Args:
    project: the GCP project in which to create the dataset.
    prefix: prefix for the generated dataset id.

  Yields:
    A 'project:dataset.table' spec for a table in the temporary dataset.
  """
  bigquery_client = BigQueryWrapper()
  dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex)
  bigquery_client.get_or_create_dataset(project, dataset_id)
  logging.info("Created dataset %s in project %s", dataset_id, project)
  try:
    yield f'{project}:{dataset_id}.tmp_table'
  finally:
    # Fix: deletion was previously outside any try/finally, so a failing test
    # body leaked the dataset.  deleteContents removes any tables within it.
    request = bigquery.BigqueryDatasetsDeleteRequest(
        projectId=project, datasetId=dataset_id, deleteContents=True)
    logging.info("Deleting dataset %s in project %s", dataset_id, project)
    bigquery_client.client.datasets.Delete(request)
60+
61+
62+
def replace_recursive(spec, vars):
  """Recursively substitutes {name} placeholders in all strings of spec.

  Dicts and lists are traversed; any string containing '{' is formatted
  with the given vars mapping.  All other values pass through unchanged.

  Raises:
    ValueError: if a placeholder cannot be evaluated against vars.
  """
  if isinstance(spec, dict):
    return {k: replace_recursive(v, vars) for k, v in spec.items()}
  if isinstance(spec, list):
    return [replace_recursive(item, vars) for item in spec]
  if isinstance(spec, str) and '{' in spec:
    try:
      return spec.format(**vars)
    except Exception as exn:
      raise ValueError(f"Error evaluating {spec}: {exn}") from exn
  return spec
77+
78+
79+
def transform_types(spec):
  """Yields the type of every leaf transform referenced in spec.

  Composite and chain specs (and specs with no explicit type) are
  recursed into via their source, transforms, and sink members.
  """
  spec_type = spec.get('type', None)
  if spec_type not in (None, 'composite', 'chain'):
    yield spec_type
    return
  if 'source' in spec:
    yield from transform_types(spec['source'])
  for transform in spec.get('transforms', []):
    yield from transform_types(transform)
  if 'sink' in spec:
    yield from transform_types(spec['sink'])
89+
90+
91+
def provider_sets(spec, require_available=False):
  """For transforms that are vended by multiple providers, yields all possible
  combinations of providers to use.

  Args:
    spec: the parsed yaml test-suite spec, containing a 'pipelines' list.
    require_available: if True, raise rather than skip when a provider that
      vends a needed transform is not available in this environment.

  Yields:
    (suffix, providers) pairs, where suffix names the combination and
    providers maps each transform type to the provider list to use.
  """
  # Every leaf transform type used by any pipeline in the suite.
  all_transform_types = set.union(
      *(
          set(
              transform_types(
                  yaml_transform.preprocess(copy.deepcopy(p['pipeline']))))
          for p in spec['pipelines']))

  def filter_to_available(t, providers):
    # Restrict to providers that can actually run here (or insist that they
    # all can when require_available is set).
    if require_available:
      for p in providers:
        if not p.available():
          # Fix: message was a plain string literal; without the f-prefix the
          # {p} and {t} placeholders were never interpolated.
          raise ValueError(f"Provider {p} required for {t} is not available.")
      return providers
    else:
      return [p for p in providers if p.available()]

  standard_providers = yaml_provider.standard_providers()
  multiple_providers = {
      t: filter_to_available(t, standard_providers[t])
      for t in all_transform_types
      if len(filter_to_available(t, standard_providers[t])) > 1
  }
  if not multiple_providers:
    # Fix: this function is a generator, so `return value` silently discards
    # the value (it becomes StopIteration.value and callers iterating see
    # nothing); the single default combination must be yielded for any test
    # methods to be generated at all.
    yield 'only', standard_providers
  else:
    names, provider_lists = zip(*sorted(multiple_providers.items()))
    for ix, c in enumerate(itertools.product(*provider_lists)):
      yield (
          '_'.join(
              n + '_' + type(p.underlying_provider()).__name__
              for (n, p) in zip(names, c)) + f'_{ix}',
          dict(standard_providers, **{n: [p]
                                      for (n, p) in zip(names, c)}))
128+
129+
130+
def create_test_methods(spec):
  """Yields (method_name, method) pairs, one per provider combination.

  One test method is generated for each provider set returned by
  provider_sets, so transforms vended by multiple providers are each
  exercised for consistent behavior.
  """
  for suffix, providers in provider_sets(spec):

    def test(self, providers=providers):  # default arg to capture loop value
      vars = {}
      with contextlib.ExitStack() as stack:
        # Pin the provider set under test for the duration of this method.
        stack.enter_context(
            mock.patch(
                'apache_beam.yaml.yaml_provider.standard_providers',
                lambda: providers))
        # Enter each declared fixture (a fully-qualified context-manager
        # factory); the entered value becomes available for {name}
        # placeholder substitution in the pipeline specs below.
        for fixture in spec.get('fixtures', []):
          vars[fixture['name']] = stack.enter_context(
              python_callable.PythonCallableWithSource.
              load_from_fully_qualified_name(fixture['type'])(
                  **yaml_transform.SafeLineLoader.strip_metadata(
                      fixture.get('config', {}))))
        # Run every pipeline in the suite with fixture values substituted in;
        # the with-block waits for each pipeline to finish before the next.
        for pipeline_spec in spec['pipelines']:
          with beam.Pipeline(options=PipelineOptions(
              pickle_library='cloudpickle',
              **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
                  'options', {})))) as p:
            yaml_transform.expand_pipeline(
                p, replace_recursive(pipeline_spec, vars))

    yield f'test_{suffix}', test
155+
156+
157+
def parse_test_files(filepattern):
  """Creates a unittest.TestCase subclass for each yaml file in filepattern.

  Each generated suite is installed into this module's globals so that
  unittest/pytest discovery can find it.
  """
  for path in glob.glob(filepattern):
    with open(path) as fin:
      suite_name = os.path.splitext(os.path.basename(path))[0].title() + 'Test'
      # Fix: was a bare debug print(); use the logging module like the rest
      # of this file so output respects the configured log level.
      logging.info('Creating test suite %s from %s', suite_name, path)
      methods = dict(
          create_test_methods(
              yaml.load(fin, Loader=yaml_transform.SafeLineLoader)))
      globals()[suite_name] = type(suite_name, (unittest.TestCase, ), methods)
166+
167+
168+
# Generate the test suites at import time (not only under __main__) so that
# test runners that merely import this module still discover the dynamically
# created TestCase classes.
logging.getLogger().setLevel(logging.INFO)
parse_test_files(os.path.join(os.path.dirname(__file__), 'tests', '*.yaml'))

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
fixtures:
19+
- name: BQ_TABLE
20+
type: "apache_beam.yaml.integration_tests.temp_bigquery_table"
21+
config:
22+
project: "apache-beam-testing"
23+
- name: TEMP_DIR
24+
# Need distributed filesystem to be able to read and write from a container.
25+
type: "apache_beam.yaml.integration_tests.gcs_temp_dir"
26+
config:
27+
bucket: "gs://temp-storage-for-end-to-end-tests/temp-it"
28+
29+
pipelines:
30+
- pipeline:
31+
type: chain
32+
transforms:
33+
- type: Create
34+
config:
35+
elements:
36+
- {label: "11a", rank: 0}
37+
- {label: "37a", rank: 1}
38+
- {label: "389a", rank: 2}
39+
- type: WriteToBigQuery
40+
config:
41+
table: "{BQ_TABLE}"
42+
options:
43+
project: "apache-beam-testing"
44+
temp_location: "{TEMP_DIR}"
45+
46+
- pipeline:
47+
type: chain
48+
transforms:
49+
- type: ReadFromBigQuery
50+
config:
51+
table: "{BQ_TABLE}"
52+
- type: AssertEqual
53+
config:
54+
elements:
55+
- {label: "11a", rank: 0}
56+
- {label: "37a", rank: 1}
57+
- {label: "389a", rank: 2}
58+
options:
59+
project: "apache-beam-testing"
60+
temp_location: "{TEMP_DIR}"
61+
62+
- pipeline:
63+
type: chain
64+
transforms:
65+
- type: ReadFromBigQuery
66+
config:
67+
table: "{BQ_TABLE}"
68+
fields: ["label"]
69+
row_restriction: "rank > 0"
70+
- type: AssertEqual
71+
config:
72+
elements:
73+
- {label: "37a"}
74+
- {label: "389a"}
75+
options:
76+
project: "apache-beam-testing"
77+
temp_location: "{TEMP_DIR}"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
fixtures:
19+
- name: TEMP_DIR
20+
type: "tempfile.TemporaryDirectory"
21+
22+
pipelines:
23+
- pipeline:
24+
type: chain
25+
transforms:
26+
- type: Create
27+
config:
28+
elements:
29+
- {label: "11a", rank: 0}
30+
- {label: "37a", rank: 1}
31+
- {label: "389a", rank: 2}
32+
- type: WriteToCsv
33+
config:
34+
path: "{TEMP_DIR}/out.csv"
35+
36+
- pipeline:
37+
type: chain
38+
transforms:
39+
- type: ReadFromCsv
40+
config:
41+
path: "{TEMP_DIR}/out.csv*"
42+
- type: AssertEqual
43+
config:
44+
elements:
45+
- {label: "11a", rank: 0}
46+
- {label: "37a", rank: 1}
47+
- {label: "389a", rank: 2}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
fixtures:
19+
- name: TEMP_DIR
20+
type: "tempfile.TemporaryDirectory"
21+
22+
pipelines:
23+
- pipeline:
24+
type: chain
25+
transforms:
26+
- type: Create
27+
config:
28+
elements:
29+
- {label: "11a", rank: 0}
30+
- {label: "37a", rank: 1}
31+
- {label: "389a", rank: 2}
32+
- type: WriteToJson
33+
config:
34+
path: "{TEMP_DIR}/out.json"
35+
36+
- pipeline:
37+
type: chain
38+
transforms:
39+
- type: ReadFromJson
40+
config:
41+
path: "{TEMP_DIR}/out.json*"
42+
- type: AssertEqual
43+
config:
44+
elements:
45+
- {label: "11a", rank: 0}
46+
- {label: "37a", rank: 1}
47+
- {label: "389a", rank: 2}

sdks/python/apache_beam/yaml/yaml_provider.py

+12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import apache_beam.transforms.util
4848
from apache_beam.portability.api import schema_pb2
4949
from apache_beam.runners import pipeline_context
50+
from apache_beam.testing.util import assert_that
51+
from apache_beam.testing.util import equal_to
5052
from apache_beam.transforms import external
5153
from apache_beam.transforms import window
5254
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
@@ -554,6 +556,15 @@ def dicts_to_rows(o):
554556

555557

556558
class YamlProviders:
559+
  class AssertEqual(beam.PTransform):
    """Asserts that the input PCollection contains exactly the given elements.

    Comparison is unordered, per apache_beam.testing.util.equal_to.
    NOTE(review): registered as the 'AssertEqual' builtin, presumably for use
    by the yaml test suites — confirm intended non-test usage.
    """
    def __init__(self, elements):
      # Expected elements as parsed yaml; converted via dicts_to_rows on use.
      self._elements = elements

    def expand(self, pcoll):
      # Rebuild each element as a plain beam.Row so that schema'd outputs
      # from different providers compare equal to the expected Rows.
      return assert_that(
          pcoll | beam.Map(lambda row: beam.Row(**row._asdict())),
          equal_to(dicts_to_rows(self._elements)))
567+
557568
@staticmethod
558569
def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
559570
"""Creates a collection containing a specified set of elements.
@@ -810,6 +821,7 @@ def log_and_return(x):
810821
@staticmethod
811822
def create_builtin_provider():
812823
return InlineProvider({
824+
'AssertEqual': YamlProviders.AssertEqual,
813825
'Create': YamlProviders.create,
814826
'LogForTesting': YamlProviders.log_for_testing,
815827
'PyTransform': YamlProviders.fully_qualified_named_transform,

0 commit comments

Comments
 (0)