Skip to content

Commit 86d4aec

Browse files
authored
Added integration test for AutoML (#477)
* added integration test for AutoML
1 parent a7fe2b2 commit 86d4aec

File tree

2 files changed

+89
-24
lines changed

2 files changed

+89
-24
lines changed

integration/test_ml.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pytest
2424

25+
import firebase_admin
2526
from firebase_admin import exceptions
2627
from firebase_admin import ml
2728
from tests import testutils
@@ -34,6 +35,11 @@
3435
except ImportError:
3536
_TF_ENABLED = False
3637

38+
try:
39+
from google.cloud import automl_v1
40+
_AUTOML_ENABLED = True
41+
except ImportError:
42+
_AUTOML_ENABLED = False
3743

3844
def _random_identifier(prefix):
3945
#pylint: disable=unused-variable
@@ -62,7 +68,6 @@ def _random_identifier(prefix):
6268
'file_name': 'invalid_model.tflite'
6369
}
6470

65-
6671
@pytest.fixture
6772
def firebase_model(request):
6873
args = request.param
@@ -101,6 +106,7 @@ def _clean_up_model(model):
101106
try:
102107
# Try to delete the model.
103108
# Some tests delete the model as part of the test.
109+
model.wait_for_unlocked()
104110
ml.delete_model(model.model_id)
105111
except exceptions.NotFoundError:
106112
pass
@@ -132,35 +138,45 @@ def check_model(model, args):
132138
assert model.locked is False
133139
assert model.etag is not None
134140

141+
# Model Format Checks
135142

136-
def check_model_format(model, has_model_format=False, validation_error=None):
137-
if has_model_format:
138-
assert model.validation_error == validation_error
139-
assert model.published is False
140-
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
141-
if validation_error:
142-
assert model.model_format.size_bytes is None
143-
assert model.model_hash is None
144-
else:
145-
assert model.model_format.size_bytes is not None
146-
assert model.model_hash is not None
147-
else:
148-
assert model.model_format is None
149-
assert model.validation_error == 'No model file has been uploaded.'
150-
assert model.published is False
143+
def check_no_model_format(model):
144+
assert model.model_format is None
145+
assert model.validation_error == 'No model file has been uploaded.'
146+
assert model.published is False
147+
assert model.model_hash is None
148+
149+
150+
def check_tflite_gcs_format(model, validation_error=None):
151+
assert model.validation_error == validation_error
152+
assert model.published is False
153+
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
154+
if validation_error:
155+
assert model.model_format.size_bytes is None
151156
assert model.model_hash is None
157+
else:
158+
assert model.model_format.size_bytes is not None
159+
assert model.model_hash is not None
160+
161+
162+
def check_tflite_automl_format(model):
163+
assert model.validation_error is None
164+
assert model.published is False
165+
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
166+
# Automl models don't have validation errors since they are references
167+
# to valid automl models.
152168

153169

154170
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
155171
def test_create_simple_model(firebase_model):
156172
check_model(firebase_model, NAME_AND_TAGS_ARGS)
157-
check_model_format(firebase_model)
173+
check_no_model_format(firebase_model)
158174

159175

160176
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
161177
def test_create_full_model(firebase_model):
162178
check_model(firebase_model, FULL_MODEL_ARGS)
163-
check_model_format(firebase_model, True)
179+
check_tflite_gcs_format(firebase_model)
164180

165181

166182
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
@@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
175191
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
176192
def test_create_invalid_model(firebase_model):
177193
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
178-
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
194+
check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format')
179195

180196

181197
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
182198
def test_get_model(firebase_model):
183199
get_model = ml.get_model(firebase_model.model_id)
184200
check_model(get_model, NAME_AND_TAGS_ARGS)
185-
check_model_format(get_model)
201+
check_no_model_format(get_model)
186202

187203

188204
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -201,12 +217,12 @@ def test_update_model(firebase_model):
201217
firebase_model.display_name = new_model_name
202218
updated_model = ml.update_model(firebase_model)
203219
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
204-
check_model_format(updated_model)
220+
check_no_model_format(updated_model)
205221

206222
# Second call with same model does not cause error
207223
updated_model2 = ml.update_model(updated_model)
208224
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
209-
check_model_format(updated_model2)
225+
check_no_model_format(updated_model2)
210226

211227

212228
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -290,7 +306,7 @@ def test_delete_model(firebase_model):
290306

291307
# Test tensor flow conversion functions if tensor flow is enabled.
292308
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
293-
#'pip install tensorflow==2.0.0b' for version 2 etc.
309+
#'pip install tensorflow==2.2.0' for version 2.2.0 etc.
294310

295311

296312
def _clean_up_directory(save_dir):
@@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
334350
_clean_up_directory(parent)
335351

336352

353+
337354
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
338355
def test_from_keras_model(keras_model):
339356
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
@@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):
348365

349366
try:
350367
check_model(created_model, {'display_name': model.display_name})
351-
check_model_format(created_model, True)
368+
check_tflite_gcs_format(created_model)
352369
finally:
353370
_clean_up_model(created_model)
354371

@@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
371388
assert created_model.validation_error is None
372389
finally:
373390
_clean_up_model(created_model)
391+
392+
393+
# Test AutoML functionality if AutoML is enabled.
394+
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
395+
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
396+
# successful test. (Test is skipped otherwise)
397+
398+
@pytest.fixture
399+
def automl_model():
400+
assert _AUTOML_ENABLED
401+
402+
# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
403+
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
404+
# the test.
405+
automl_client = automl_v1.AutoMlClient()
406+
project_id = firebase_admin.get_app().project_id
407+
parent = automl_client.location_path(project_id, 'us-central1')
408+
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
409+
# Expecting exactly one. (Ok to use last one if somehow more than 1)
410+
automl_ref = None
411+
for model in models:
412+
automl_ref = model.name
413+
414+
# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
415+
if automl_ref is None:
416+
pytest.skip("No pre-existing AutoML model found. Skipping test")
417+
418+
source = ml.TFLiteAutoMlSource(automl_ref)
419+
tflite_format = ml.TFLiteFormat(model_source=source)
420+
ml_model = ml.Model(
421+
display_name=_random_identifier('TestModel_automl_'),
422+
tags=['test_automl'],
423+
model_format=tflite_format)
424+
model = ml.create_model(model=ml_model)
425+
yield model
426+
_clean_up_model(model)
427+
428+
@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
429+
def test_automl_model(automl_model):
430+
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
431+
automl_model.wait_for_unlocked()
432+
433+
check_model(automl_model, {
434+
'display_name': automl_model.display_name,
435+
'tags': ['test_automl'],
436+
})
437+
check_tflite_automl_format(automl_model)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1
77
cachecontrol >= 0.12.6
88
google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy'
99
google-api-python-client >= 1.7.8
10+
google-auth == 1.18.0 # temporary workaround
1011
google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy'
1112
google-cloud-storage >= 1.18.0

0 commit comments

Comments
 (0)