#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Runs integration tests in the tests directory."""

import contextlib
import copy
import glob
import itertools
import logging
import os
import unittest
import uuid

import mock
import yaml

import apache_beam as beam
from apache_beam.io import filesystems
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import python_callable
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform


@contextlib.contextmanager
def gcs_temp_dir(bucket):
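  """Context manager yielding a unique temporary GCS path under the given
  bucket; the path is deleted when the context exits."""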
  gcs_tempdir = bucket + '/yaml-' + str(uuid.uuid4())
  yield gcs_tempdir
  filesystems.FileSystems.delete([gcs_tempdir])


@contextlib.contextmanager
def temp_bigquery_table(project, prefix='yaml_bq_it_'):
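  """Context manager that creates a temporary BigQuery dataset in the given
  project and yields a table reference inside it; the dataset and its
  contents are deleted when the context exits."""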
  bigquery_client = BigQueryWrapper()
  dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex)
  bigquery_client.get_or_create_dataset(project, dataset_id)
  logging.info("Created dataset %s in project %s", dataset_id, project)
  yield f'{project}:{dataset_id}.tmp_table'
  request = bigquery.BigqueryDatasetsDeleteRequest(
      projectId=project, datasetId=dataset_id, deleteContents=True)
  logging.info("Deleting dataset %s in project %s", dataset_id, project)
  bigquery_client.client.datasets.Delete(request)


def replace_recursive(spec, vars):
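  """Recursively walks a spec, formatting any string containing '{' with the
  given vars so that "{name}" placeholders pick up fixture values."""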
  if isinstance(spec, dict):
    return {
        key: replace_recursive(value, vars)
        for (key, value) in spec.items()
    }
  elif isinstance(spec, list):
    return [replace_recursive(value, vars) for value in spec]
  elif isinstance(spec, str) and '{' in spec:
    try:
      return spec.format(**vars)
    except Exception as exn:
      raise ValueError(f"Error evaluating {spec}: {exn}") from exn
  else:
    return spec


def transform_types(spec):
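  """Yields the types of all leaf transforms in a preprocessed pipeline spec,
  descending into composite and chain blocks."""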
  if spec.get('type', None) in (None, 'composite', 'chain'):
    if 'source' in spec:
      yield from transform_types(spec['source'])
    for t in spec.get('transforms', []):
      yield from transform_types(t)
    if 'sink' in spec:
      yield from transform_types(spec['sink'])
  else:
    yield spec['type']


def provider_sets(spec, require_available=False):
  """For transforms that are vended by multiple providers, yields all possible
  combinations of providers to use.
  """
  all_transform_types = set.union(
      *(
          set(
              transform_types(
                  yaml_transform.preprocess(copy.deepcopy(p['pipeline']))))
          for p in spec['pipelines']))

  def filter_to_available(t, providers):
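    """Returns the providers for t that are available, raising instead if
    require_available is set and any of them is not."""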
    if require_available:
      for p in providers:
        if not p.available():
          raise ValueError(f"Provider {p} required for {t} is not available.")
      return providers
    else:
      return [p for p in providers if p.available()]

  standard_providers = yaml_provider.standard_providers()
  multiple_providers = {
      t: filter_to_available(t, standard_providers[t])
      for t in all_transform_types
      if len(filter_to_available(t, standard_providers[t])) > 1
  }
  if not multiple_providers:
    # This is a generator, so the single default provider set must be yielded;
    # a bare `return value` here would be silently discarded by callers
    # iterating over provider_sets, producing no tests at all.
    yield 'only', standard_providers
  else:
    names, provider_lists = zip(*sorted(multiple_providers.items()))
    for ix, c in enumerate(itertools.product(*provider_lists)):
      yield (
          '_'.join(
              n + '_' + type(p.underlying_provider()).__name__
              for (n, p) in zip(names, c)) + f'_{ix}',
          dict(standard_providers, **{n: [p]
                                      for (n, p) in zip(names, c)}))


def create_test_methods(spec):
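  """Yields (name, callable) pairs, one test method per provider combination,
  each of which instantiates the declared fixtures and runs every pipeline
  in the spec against the chosen providers."""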
  for suffix, providers in provider_sets(spec):

    def test(self, providers=providers):  # default arg to capture loop value
      vars = {}
      with contextlib.ExitStack() as stack:
        stack.enter_context(
            mock.patch(
                'apache_beam.yaml.yaml_provider.standard_providers',
                lambda: providers))
        for fixture in spec.get('fixtures', []):
          vars[fixture['name']] = stack.enter_context(
              python_callable.PythonCallableWithSource.
              load_from_fully_qualified_name(fixture['type'])(
                  **yaml_transform.SafeLineLoader.strip_metadata(
                      fixture.get('config', {}))))
        for pipeline_spec in spec['pipelines']:
          with beam.Pipeline(options=PipelineOptions(
              pickle_library='cloudpickle',
              **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
                  'options', {})))) as p:
            yaml_transform.expand_pipeline(
                p, replace_recursive(pipeline_spec, vars))

    yield f'test_{suffix}', test


def parse_test_files(filepattern):
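  """For each YAML file matching filepattern, registers a unittest.TestCase
  subclass in this module's globals with one test method per provider
  combination."""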
  for path in glob.glob(filepattern):
    with open(path) as fin:
      suite_name = os.path.splitext(os.path.basename(path))[0].title() + 'Test'
      print(path, suite_name)
      methods = dict(
          create_test_methods(
              yaml.load(fin, Loader=yaml_transform.SafeLineLoader)))
      globals()[suite_name] = type(suite_name, (unittest.TestCase, ), methods)


logging.getLogger().setLevel(logging.INFO)
parse_test_files(os.path.join(os.path.dirname(__file__), 'tests', '*.yaml'))

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()