Skip to content

Commit 9802a06

Browse files
authored
Merge pull request #34569 Add the ability to unit test YAML pipelines.
This follows the ideas at https://s.apache.org/beam-yaml-testing.
2 parents da8c5e0 + 3f3dc4f commit 9802a06

File tree

12 files changed

+1655
-17
lines changed

12 files changed

+1655
-17
lines changed

sdks/python/apache_beam/yaml/main.py

+130-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import argparse
1919
import contextlib
2020
import json
21+
import os
22+
import sys
23+
import unittest
2124

2225
import yaml
2326

@@ -26,7 +29,9 @@
2629
from apache_beam.transforms import resources
2730
from apache_beam.typehints.schemas import LogicalType
2831
from apache_beam.typehints.schemas import MillisInstant
32+
from apache_beam.yaml import yaml_testing
2933
from apache_beam.yaml import yaml_transform
34+
from apache_beam.yaml import yaml_utils
3035

3136

3237
def _preparse_jinja_flags(argv):
@@ -90,6 +95,28 @@ def _parse_arguments(argv):
9095
type=json.loads,
9196
help='A json dict of variables used when invoking the jinja preprocessor '
9297
'on the provided yaml pipeline.')
98+
parser.add_argument(
99+
'--test',
100+
action=argparse.BooleanOptionalAction,
101+
help='Run the tests associated with the given pipeline, rather than the '
102+
'pipeline itself.')
103+
parser.add_argument(
104+
'--fix_tests',
105+
action=argparse.BooleanOptionalAction,
106+
help='Update failing test expectations to match the actual ouput. '
107+
'Requires --test_suite if the pipeline uses jinja formatting.')
108+
parser.add_argument(
109+
'--create_test',
110+
action=argparse.BooleanOptionalAction,
111+
help='Automatically creates a regression test for the given pipeline, '
112+
'adding it to the pipeline spec or test suite dependon on whether '
113+
'--test_suite is given. '
114+
'Requires --test_suite if the pipeline uses jinja formatting.')
115+
parser.add_argument(
116+
'--test_suite',
117+
help='Run the given tests against the given pipeline, rather than the '
118+
'pipeline itself. '
119+
'Should be a file containing a list of yaml test specifications.')
93120
return parser.parse_known_args(argv)
94121

95122

@@ -130,12 +157,109 @@ def run(argv=None):
130157
print('Running pipeline...')
131158

132159

133-
def build_pipeline_components_from_argv(argv):
160+
def run_tests(argv=None, exit=True):
161+
known_args, pipeline_args, _, pipeline_yaml = _build_pipeline_yaml_from_argv(
162+
argv)
163+
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
164+
options = _build_pipeline_options(pipeline_spec, pipeline_args)
165+
166+
if known_args.create_test and known_args.fix_tests:
167+
raise ValueError(
168+
'At most one of --create_test and --fix_tests may be specified.')
169+
elif known_args.create_test:
170+
result = unittest.TestResult()
171+
tests = []
172+
else:
173+
if known_args.test_suite:
174+
with open(known_args.test_suite) as fin:
175+
test_suite_holder = yaml.load(
176+
fin, Loader=yaml_transform.SafeLineLoader) or {}
177+
else:
178+
test_suite_holder = pipeline_spec
179+
test_specs = test_suite_holder.get('tests', [])
180+
if not isinstance(test_specs, list):
181+
raise TypeError('tests attribute must be a list of test specifications.')
182+
elif not test_specs:
183+
raise RuntimeError(
184+
'No tests found. '
185+
"If you haven't added a set of tests yet, you can get started by "
186+
'running your pipeline with the --create_test flag enabled.')
187+
188+
with _fix_xlang_instant_coding():
189+
tests = [
190+
yaml_testing.YamlTestCase(
191+
pipeline_spec, test_spec, options, known_args.fix_tests)
192+
for test_spec in test_specs
193+
]
194+
suite = unittest.TestSuite(tests)
195+
result = unittest.TextTestRunner().run(suite)
196+
197+
if known_args.fix_tests or known_args.create_test:
198+
update_tests(known_args, pipeline_yaml, pipeline_spec, options, tests)
199+
200+
if exit:
201+
# emulates unittest.main()
202+
sys.exit(0 if result.wasSuccessful() else 1)
203+
else:
204+
if not result.wasSuccessful():
205+
raise RuntimeError(result)
206+
207+
208+
def update_tests(known_args, pipeline_yaml, pipeline_spec, options, tests):
209+
if known_args.test_suite:
210+
path = known_args.test_suite
211+
if not os.path.exists(path) and known_args.create_test:
212+
with open(path, 'w') as fout:
213+
fout.write('tests: []')
214+
elif known_args.yaml_pipeline_file:
215+
path = known_args.yaml_pipeline_file
216+
else:
217+
raise RuntimeError(
218+
'Test fixing only supported for file-backed tests. '
219+
'Please use the --test_suite flag.')
220+
with open(path) as fin:
221+
original_yaml = fin.read()
222+
if path == known_args.yaml_pipeline_file and pipeline_yaml.strip(
223+
) != original_yaml.strip():
224+
raise RuntimeError(
225+
'In-file test fixing not yet supported for templated pipelines. '
226+
'Please use the --test_suite flag.')
227+
updated_spec = yaml.load(original_yaml, Loader=yaml.SafeLoader) or {}
228+
229+
if known_args.fix_tests:
230+
updated_spec['tests'] = [test.fixed_test() for test in tests]
231+
232+
if known_args.create_test:
233+
if 'tests' not in updated_spec:
234+
updated_spec['tests'] = []
235+
updated_spec['tests'].append(
236+
yaml_testing.create_test(pipeline_spec, options))
237+
238+
updated_yaml = yaml_utils.patch_yaml(original_yaml, updated_spec)
239+
with open(path, 'w') as fout:
240+
fout.write(updated_yaml)
241+
242+
243+
def _build_pipeline_yaml_from_argv(argv):
134244
argv = _preparse_jinja_flags(argv)
135245
known_args, pipeline_args = _parse_arguments(argv)
136246
pipeline_template = _pipeline_spec_from_args(known_args)
137247
pipeline_yaml = yaml_transform.expand_jinja(
138248
pipeline_template, known_args.jinja_variables or {})
249+
return known_args, pipeline_args, pipeline_template, pipeline_yaml
250+
251+
252+
def _build_pipeline_options(pipeline_spec, pipeline_args):
253+
return beam.options.pipeline_options.PipelineOptions(
254+
pipeline_args,
255+
pickle_library='cloudpickle',
256+
**yaml_transform.SafeLineLoader.strip_metadata(
257+
pipeline_spec.get('options', {})))
258+
259+
260+
def build_pipeline_components_from_argv(argv):
261+
(known_args, pipeline_args, pipeline_template,
262+
pipeline_yaml) = _build_pipeline_yaml_from_argv(argv)
139263
display_data = {
140264
'yaml': pipeline_yaml,
141265
'yaml_jinja_template': pipeline_template,
@@ -154,11 +278,7 @@ def build_pipeline_components_from_yaml(
154278
pipeline_yaml, pipeline_args, validate_schema='generic', pipeline_path=''):
155279
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
156280

157-
options = beam.options.pipeline_options.PipelineOptions(
158-
pipeline_args,
159-
pickle_library='cloudpickle',
160-
**yaml_transform.SafeLineLoader.strip_metadata(
161-
pipeline_spec.get('options', {})))
281+
options = _build_pipeline_options(pipeline_spec, pipeline_args)
162282

163283
def constructor(root):
164284
if 'resource_hints' in pipeline_spec.get('pipeline', {}):
@@ -180,4 +300,7 @@ def constructor(root):
180300
if __name__ == '__main__':
181301
import logging
182302
logging.getLogger().setLevel(logging.INFO)
183-
run()
303+
if '--test' in sys.argv:
304+
run_tests()
305+
else:
306+
run()

sdks/python/apache_beam/yaml/main_test.py

+114
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,44 @@
3838
- type: WriteToText
3939
config:
4040
path: PATH
41+
42+
tests:
43+
- name: InlineTest
44+
mock_outputs:
45+
- name: Create
46+
elements: ['a', 'b', 'c']
47+
expected_inputs:
48+
- name: WriteToText
49+
elements:
50+
- {element: a}
51+
- {element: b}
52+
- {element: c}
53+
'''
54+
55+
PASSING_TEST_SUITE = '''
56+
tests:
57+
- name: ExternalTest # comment
58+
mock_outputs:
59+
- name: Create
60+
elements: ['a', 'b', 'c']
61+
expected_inputs:
62+
- name: WriteToText
63+
elements:
64+
- element: a
65+
- element: b
66+
- element: c
67+
'''
68+
69+
FAILING_TEST_SUITE = '''
70+
tests:
71+
- name: ExternalTest # comment
72+
mock_outputs:
73+
- name: Create
74+
elements: ['a', 'b', 'c']
75+
expected_inputs:
76+
- name: WriteToText
77+
elements:
78+
- element: x
4179
'''
4280

4381

@@ -113,6 +151,82 @@ def test_jinja_datetime(self):
113151
self.assertEqual(
114152
fin.read().strip(), datetime.datetime.now().strftime("%Y-%m-%d"))
115153

154+
def test_inline_test_specs(self):
155+
main.run_tests(['--yaml_pipeline', TEST_PIPELINE, '--test'], exit=False)
156+
157+
def test_external_test_specs(self):
158+
with tempfile.TemporaryDirectory() as tmpdir:
159+
good_suite = os.path.join(tmpdir, 'good.yaml')
160+
with open(good_suite, 'w') as fout:
161+
fout.write(PASSING_TEST_SUITE)
162+
bad_suite = os.path.join(tmpdir, 'bad.yaml')
163+
with open(bad_suite, 'w') as fout:
164+
fout.write(FAILING_TEST_SUITE)
165+
166+
# Must pass.
167+
main.run_tests([
168+
'--yaml_pipeline',
169+
TEST_PIPELINE,
170+
'--test_suite',
171+
good_suite,
172+
],
173+
exit=False)
174+
175+
# Must fail. (Ensures testing is not a no-op.)
176+
with self.assertRaisesRegex(Exception, 'errors=1 failures=0'):
177+
main.run_tests([
178+
'--yaml_pipeline',
179+
TEST_PIPELINE,
180+
'--test_suite',
181+
bad_suite,
182+
],
183+
exit=False)
184+
185+
def test_fix_suite(self):
186+
with tempfile.TemporaryDirectory() as tmpdir:
187+
test_suite = os.path.join(tmpdir, 'tests.yaml')
188+
with open(test_suite, 'w') as fout:
189+
fout.write(FAILING_TEST_SUITE)
190+
191+
main.run_tests([
192+
'--yaml_pipeline',
193+
TEST_PIPELINE,
194+
'--test_suite',
195+
test_suite,
196+
'--fix_tests'
197+
],
198+
exit=False)
199+
200+
with open(test_suite) as fin:
201+
self.assertEqual(fin.read(), PASSING_TEST_SUITE)
202+
203+
def test_create_test(self):
204+
with tempfile.TemporaryDirectory() as tmpdir:
205+
test_suite = os.path.join(tmpdir, 'tests.yaml')
206+
with open(test_suite, 'w') as fout:
207+
fout.write('')
208+
209+
main.run_tests([
210+
'--yaml_pipeline',
211+
TEST_PIPELINE.replace('ELEMENT', 'x'),
212+
'--test_suite',
213+
test_suite,
214+
'--create_test'
215+
],
216+
exit=False)
217+
218+
with open(test_suite) as fin:
219+
self.assertEqual(
220+
fin.read(),
221+
'''
222+
tests:
223+
- mock_outputs: []
224+
expected_inputs:
225+
- name: WriteToText
226+
elements:
227+
- element: x
228+
'''.lstrip())
229+
116230

117231
if __name__ == '__main__':
118232
logging.getLogger().setLevel(logging.INFO)

sdks/python/apache_beam/yaml/readme_test.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from apache_beam.options.pipeline_options import PipelineOptions
3535
from apache_beam.typehints import trivial_inference
3636
from apache_beam.yaml import yaml_provider
37+
from apache_beam.yaml import yaml_testing
3738
from apache_beam.yaml import yaml_transform
39+
from apache_beam.yaml import yaml_utils
3840

3941

4042
class FakeSql(beam.PTransform):
@@ -288,6 +290,7 @@ def extract_name(input_spec):
288290
return input_spec.get('name', input_spec.get('type'))
289291

290292
code_lines = None
293+
last_pipeline = None
291294
for ix, line in enumerate(markdown_lines):
292295
line = line.rstrip()
293296
if line == '```':
@@ -320,12 +323,30 @@ def extract_name(input_spec):
320323
] + [' ' + line for line in code_lines]
321324
if code_lines[0] == 'pipeline:':
322325
yaml_pipeline = '\n'.join(code_lines)
323-
if 'providers:' in yaml_pipeline:
326+
last_pipeline = yaml_pipeline
327+
if 'providers:' in yaml_pipeline or 'tests:' in yaml_pipeline:
324328
test_type = 'PARSE'
325329
yield test_name, create_test_method(
326330
test_type,
327331
test_name,
328332
yaml_pipeline)
333+
if 'tests:' in code_lines:
334+
test_spec = '\n'.join(code_lines)
335+
if code_lines[0] == 'pipeline:':
336+
yaml_pipeline = '\n'.join(code_lines)
337+
else:
338+
yaml_pipeline = last_pipeline
339+
for sub_ix, test_spec in enumerate(yaml.load(
340+
'\n'.join(code_lines),
341+
Loader=yaml_utils.SafeLineLoader)['tests']):
342+
suffix = test_spec.get('name', str(sub_ix))
343+
yield (
344+
test_name + '_' + suffix,
345+
# The yp=... ts=... is to capture the looped closure values.
346+
lambda _,
347+
yp=yaml_pipeline,
348+
ts=test_spec: yaml_testing.run_test(yp, ts))
349+
329350
code_lines = None
330351
elif code_lines is not None:
331352
code_lines.append(line)
@@ -358,6 +379,9 @@ def createTestSuite(name, path):
358379
JoinTest = createTestSuite(
359380
'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md'))
360381

382+
TestingTest = createTestSuite(
383+
'TestingTest', os.path.join(YAML_DOCS_DIR, 'yaml-testing.md'))
384+
361385
if __name__ == '__main__':
362386
parser = argparse.ArgumentParser()
363387
parser.add_argument('--render_dir', default=None)

0 commit comments

Comments
 (0)