Skip to content

Added integration test for AutoML #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 88 additions & 24 deletions integration/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils
Expand All @@ -34,6 +35,11 @@
except ImportError:
_TF_ENABLED = False

try:
from google.cloud import automl_v1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this is merged to master, let's also update the .github/workflows/release.yml to install this package, and add the Auto ML resource to the project Kobayashi Maru.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

_AUTOML_ENABLED = True
except ImportError:
_AUTOML_ENABLED = False

def _random_identifier(prefix):
#pylint: disable=unused-variable
Expand Down Expand Up @@ -62,7 +68,6 @@ def _random_identifier(prefix):
'file_name': 'invalid_model.tflite'
}


@pytest.fixture
def firebase_model(request):
args = request.param
Expand Down Expand Up @@ -101,6 +106,7 @@ def _clean_up_model(model):
try:
# Try to delete the model.
# Some tests delete the model as part of the test.
model.wait_for_unlocked()
ml.delete_model(model.model_id)
except exceptions.NotFoundError:
pass
Expand Down Expand Up @@ -132,35 +138,45 @@ def check_model(model, args):
assert model.locked is False
assert model.etag is not None

# Model Format Checks

def check_model_format(model, has_model_format=False, validation_error=None):
if has_model_format:
assert model.validation_error == validation_error
assert model.published is False
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None
else:
assert model.model_format is None
assert model.validation_error == 'No model file has been uploaded.'
assert model.published is False
def check_no_model_format(model):
assert model.model_format is None
assert model.validation_error == 'No model file has been uploaded.'
assert model.published is False
assert model.model_hash is None


def check_tflite_gcs_format(model, validation_error=None):
assert model.validation_error == validation_error
assert model.published is False
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None


def check_tflite_automl_format(model):
assert model.validation_error is None
assert model.published is False
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
# Automl models don't have validation errors since they are references
# to valid automl models.


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_create_simple_model(firebase_model):
check_model(firebase_model, NAME_AND_TAGS_ARGS)
check_model_format(firebase_model)
check_no_model_format(firebase_model)


@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
def test_create_full_model(firebase_model):
check_model(firebase_model, FULL_MODEL_ARGS)
check_model_format(firebase_model, True)
check_tflite_gcs_format(firebase_model)


@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
Expand All @@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
def test_create_invalid_model(firebase_model):
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format')


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_get_model(firebase_model):
get_model = ml.get_model(firebase_model.model_id)
check_model(get_model, NAME_AND_TAGS_ARGS)
check_model_format(get_model)
check_no_model_format(get_model)


@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
Expand All @@ -201,12 +217,12 @@ def test_update_model(firebase_model):
firebase_model.display_name = new_model_name
updated_model = ml.update_model(firebase_model)
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
check_model_format(updated_model)
check_no_model_format(updated_model)

# Second call with same model does not cause error
updated_model2 = ml.update_model(updated_model)
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
check_model_format(updated_model2)
check_no_model_format(updated_model2)


@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
Expand Down Expand Up @@ -290,7 +306,7 @@ def test_delete_model(firebase_model):

# Test tensor flow conversion functions if tensor flow is enabled.
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
#'pip install tensorflow==2.0.0b' for version 2 etc.
#'pip install tensorflow==2.2.0' for version 2.2.0 etc.


def _clean_up_directory(save_dir):
Expand Down Expand Up @@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
_clean_up_directory(parent)



@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
def test_from_keras_model(keras_model):
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
Expand All @@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):

try:
check_model(created_model, {'display_name': model.display_name})
check_model_format(created_model, True)
check_tflite_gcs_format(created_model)
finally:
_clean_up_model(created_model)

Expand All @@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
assert created_model.validation_error is None
finally:
_clean_up_model(created_model)


# Test AutoML functionality if AutoML is enabled.
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
# successful test. (Test is skipped otherwise)

@pytest.fixture
def automl_model():
assert _AUTOML_ENABLED

# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
# the test.
automl_client = automl_v1.AutoMlClient()
project_id = firebase_admin.get_app().project_id
parent = automl_client.location_path(project_id, 'us-central1')
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
# Expecting exactly one. (Ok to use last one if somehow more than 1)
automl_ref = None
for model in models:
automl_ref = model.name

# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
if automl_ref is None:
pytest.skip("No pre-existing AutoML model found. Skipping test")

source = ml.TFLiteAutoMlSource(automl_ref)
tflite_format = ml.TFLiteFormat(model_source=source)
ml_model = ml.Model(
display_name=_random_identifier('TestModel_automl_'),
tags=['test_automl'],
model_format=tflite_format)
model = ml.create_model(model=ml_model)
yield model
_clean_up_model(model)

@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
automl_model.wait_for_unlocked()

check_model(automl_model, {
'display_name': automl_model.display_name,
'tags': ['test_automl'],
})
check_tflite_automl_format(automl_model)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1
cachecontrol >= 0.12.6
google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy'
google-api-python-client >= 1.7.8
google-auth == 1.18.0 # temporary workaround
google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy'
google-cloud-storage >= 1.18.0