feat(ml): Adding Firebase ML support for AutoML models #489

Merged
merged 41 commits into from
Sep 11, 2020
41 commits
f69e14c
Introduced the exceptions module (#296)
hiranya911 Jun 12, 2019
2879a22
Migrating FCM Send APIs to the New Exceptions (#297)
hiranya911 Jun 20, 2019
fa843f3
Migrated remaining messaging APIs to new error types (#298)
hiranya911 Jun 26, 2019
b27216b
Introducing TokenSignError to represent custom token creation errors …
hiranya911 Jul 5, 2019
99929ed
Raising FirebaseError from create_session_cookie() API (#306)
hiranya911 Jul 18, 2019
9fb0766
Introducing UserNotFoundError type (#309)
hiranya911 Jul 18, 2019
8a0cf08
New error handling support in create/update/delete user APIs (#311)
hiranya911 Jul 26, 2019
29c8b7a
Error handling improvements in email action link APIs (#312)
hiranya911 Jul 31, 2019
3361452
Project management API migrated to new error types (#314)
hiranya911 Jul 31, 2019
dbb6970
Error handling updated for remaining user_mgt APIs (#315)
hiranya911 Aug 2, 2019
baf4991
Merged with master
hiranya911 Aug 3, 2019
1210723
Migrated token verification APIs to new exception types (#317)
hiranya911 Aug 5, 2019
299e808
Migrated the db module to the new exception types (#318)
hiranya911 Aug 5, 2019
030f6e6
Adding a few overlooked error types (#319)
hiranya911 Aug 8, 2019
7974c05
Removing the ability to delete user properties by passing None (#320)
hiranya911 Aug 9, 2019
dd3c4bd
Adding beginning of _MLKitService (#323)
ifielker Aug 14, 2019
65f64c0
Firebase ML Kit Get Model API implementation (#326)
ifielker Aug 19, 2019
a84d3f6
Firebase ML Kit Delete Model API implementation (#327)
ifielker Aug 19, 2019
a247f13
Firebase ML Kit List Models API implementation (#331)
ifielker Aug 21, 2019
4618b1e
Implementation of Model, ModelFormat, TFLiteModelSource and subclasse…
ifielker Aug 29, 2019
e5cf14a
Firebase ML Kit Create Model API implementation (#337)
ifielker Sep 11, 2019
2a3be77
Firebase ML Kit Update Model API implementation (#343)
ifielker Sep 11, 2019
0344172
Firebase ML Kit Publish and Unpublish Implementation (#345)
ifielker Sep 11, 2019
cd5e82a
Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation…
ifielker Sep 17, 2019
7b4731f
Quick pass at filling in missing docstrings (#367)
kevinthecheung Nov 18, 2019
c6dafbb
Modify Operation Handling to not require a name for Done Operations (…
ifielker Dec 10, 2019
079c7e1
rename from mlkit to ml (#373)
ifielker Dec 10, 2019
a13d2a7
Adding File naming capability to from_saved_model and from_keras_mode…
ifielker Dec 13, 2019
b133bbb
Firebase ML Modify Operation Handling Code to match rpc codes not htm…
ifielker Jan 23, 2020
8770448
Mlkit fix date handling2 (#391)
ifielker Jan 27, 2020
cf748c8
Firebase Ml Fix upload file naming (#392)
ifielker Jan 27, 2020
0b70687
Integration tests for Firebase ML (#394)
ifielker Jan 30, 2020
c0094ed
Merged with master
hiranya911 Jan 30, 2020
7295ea4
Fixing lint errors for Py3 (#401)
hiranya911 Jan 30, 2020
bcefca8
Modifying operation handling to support backend changes (#423)
ifielker Mar 20, 2020
e49add8
Firebase ML Changing service endpoint (#421)
ifielker Mar 20, 2020
c4275be
added support for automl-models (#428)
ifielker Mar 27, 2020
a7fe2b2
merged master
ifielker Jul 13, 2020
86d4aec
Added integration test for AutoML (#477)
ifielker Jul 13, 2020
4d6daec
Pydoc edits (#480)
kevinthecheung Aug 12, 2020
e100ea1
Merge branch 'master' into mlkit-automl
hiranya911 Sep 10, 2020
73 changes: 59 additions & 14 deletions firebase_admin/ml.py
Expand Up @@ -53,6 +53,9 @@
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$')
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_AUTO_ML_MODEL_PATTERN = re.compile(
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
r'models/(?P<model_id>[A-Za-z0-9]+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
Expand All @@ -75,7 +78,7 @@ def _get_ml_service(app):


def create_model(model, app=None):
"""Creates a model in Firebase ML.
"""Creates a model in the current Firebase project.

Args:
model: An ml.Model to create.
Expand All @@ -89,7 +92,7 @@ def create_model(model, app=None):


def update_model(model, app=None):
"""Updates a model in Firebase ML.
"""Updates a model's metadata or model file.

Args:
model: The ml.Model to update.
Expand All @@ -103,7 +106,9 @@ def update_model(model, app=None):


def publish_model(model_id, app=None):
"""Publishes a model in Firebase ML.
"""Publishes a Firebase ML model.

A published model can be downloaded to client apps.

Args:
model_id: The id of the model to publish.
Expand All @@ -117,7 +122,7 @@ def publish_model(model_id, app=None):


def unpublish_model(model_id, app=None):
"""Unpublishes a model in Firebase ML.
"""Unpublishes a Firebase ML model.

Args:
model_id: The id of the model to unpublish.
Expand All @@ -131,7 +136,7 @@ def unpublish_model(model_id, app=None):


def get_model(model_id, app=None):
"""Gets a model from Firebase ML.
"""Gets the model specified by the given ID.

Args:
model_id: The id of the model to get.
Expand All @@ -145,7 +150,7 @@ def get_model(model_id, app=None):


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
"""Lists models from Firebase ML.
"""Lists the current project's models.

Args:
list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models.
Expand All @@ -164,7 +169,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None):


def delete_model(model_id, app=None):
"""Deletes a model from Firebase ML.
"""Deletes a model from the current project.

Args:
model_id: The id of the model you wish to delete.
Expand Down Expand Up @@ -363,15 +368,10 @@ def __init__(self, model_source=None):
def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
model_source = None
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
tflite_format = TFLiteFormat(model_source=model_source)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format


def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
Expand All @@ -381,6 +381,16 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@staticmethod
def _init_model_source(data):
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
auto_ml_model = data.pop('automlModel', None)
if auto_ml_model:
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
return None

@property
def model_source(self):
"""The TF Lite model's location."""
Expand Down Expand Up @@ -593,8 +603,38 @@ def as_dict(self, for_upload=False):
return {'gcsTfliteUri': self._gcs_tflite_uri}


class TFLiteAutoMlSource(TFLiteModelSource):
"""TFLite model source representing a tflite model created with AutoML."""

def __init__(self, auto_ml_model, app=None):
self._app = app
self.auto_ml_model = auto_ml_model

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.auto_ml_model == other.auto_ml_model
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def auto_ml_model(self):
"""Resource name of the model, created by the AutoML API or Cloud console."""
return self._auto_ml_model

@auto_ml_model.setter
def auto_ml_model(self, auto_ml_model):
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
# Upload is irrelevant for auto_ml models
return {'automlModel': self._auto_ml_model}


class ListModelsPage:
"""Represents a page of models in a firebase project.
"""Represents a page of models in a Firebase project.

Provides methods for traversing the models included in this page, as well as
retrieving subsequent pages of models. The iterator returned by
Expand Down Expand Up @@ -740,6 +780,11 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri

def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
return model


def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
Expand Down
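
For readers following the `firebase_admin/ml.py` changes, here is a minimal, standalone sketch of what the new `_AUTO_ML_MODEL_PATTERN` accepts. The regular expression is copied from the diff; the helper name `parse_auto_ml_model` and the sample resource name are assumptions made up for illustration and are not part of the SDK.

```python
import re

# Pattern copied from the diff; the two raw strings are joined implicitly here.
_AUTO_ML_MODEL_PATTERN = re.compile(
    r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/'
    r'models/(?P<model_id>[A-Za-z0-9]+)$')


def parse_auto_ml_model(name):
    """Hypothetical helper: return the named groups, or raise ValueError."""
    match = _AUTO_ML_MODEL_PATTERN.match(name)
    if not match:
        raise ValueError('Model resource name format is invalid.')
    return match.groupdict()


# Sample name invented for this sketch; it includes the required location segment.
parts = parse_auto_ml_model(
    'projects/my-project/locations/us-central1/models/ICN12345')
print(parts)
```

A Firebase-style name without the `locations/...` segment (the shape matched by `_RESOURCE_NAME_PATTERN`) is rejected by this pattern. Note also that `Pattern.match` raises `TypeError` for a non-string argument such as `None`, so `_validate_auto_ml_model` as written in the diff would surface that error rather than the `ValueError`.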
112 changes: 88 additions & 24 deletions integration/test_ml.py
Expand Up @@ -22,6 +22,7 @@

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils
Expand All @@ -34,6 +35,11 @@
except ImportError:
_TF_ENABLED = False

try:
from google.cloud import automl_v1
_AUTOML_ENABLED = True
except ImportError:
_AUTOML_ENABLED = False

def _random_identifier(prefix):
#pylint: disable=unused-variable
Expand Down Expand Up @@ -62,7 +68,6 @@ def _random_identifier(prefix):
'file_name': 'invalid_model.tflite'
}


@pytest.fixture
def firebase_model(request):
args = request.param
Expand Down Expand Up @@ -101,6 +106,7 @@ def _clean_up_model(model):
try:
# Try to delete the model.
# Some tests delete the model as part of the test.
model.wait_for_unlocked()
ml.delete_model(model.model_id)
except exceptions.NotFoundError:
pass
Expand Down Expand Up @@ -132,35 +138,45 @@ def check_model(model, args):
assert model.locked is False
assert model.etag is not None

# Model Format Checks

def check_model_format(model, has_model_format=False, validation_error=None):
if has_model_format:
assert model.validation_error == validation_error
assert model.published is False
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None
else:
assert model.model_format is None
assert model.validation_error == 'No model file has been uploaded.'
assert model.published is False
def check_no_model_format(model):
assert model.model_format is None
assert model.validation_error == 'No model file has been uploaded.'
assert model.published is False
assert model.model_hash is None


def check_tflite_gcs_format(model, validation_error=None):
assert model.validation_error == validation_error
assert model.published is False
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None


def check_tflite_automl_format(model):
assert model.validation_error is None
assert model.published is False
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
# Automl models don't have validation errors since they are references
# to valid automl models.


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_create_simple_model(firebase_model):
check_model(firebase_model, NAME_AND_TAGS_ARGS)
check_model_format(firebase_model)
check_no_model_format(firebase_model)


@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
def test_create_full_model(firebase_model):
check_model(firebase_model, FULL_MODEL_ARGS)
check_model_format(firebase_model, True)
check_tflite_gcs_format(firebase_model)


@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
Expand All @@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
def test_create_invalid_model(firebase_model):
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format')


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_get_model(firebase_model):
get_model = ml.get_model(firebase_model.model_id)
check_model(get_model, NAME_AND_TAGS_ARGS)
check_model_format(get_model)
check_no_model_format(get_model)


@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
Expand All @@ -201,12 +217,12 @@ def test_update_model(firebase_model):
firebase_model.display_name = new_model_name
updated_model = ml.update_model(firebase_model)
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
check_model_format(updated_model)
check_no_model_format(updated_model)

# Second call with same model does not cause error
updated_model2 = ml.update_model(updated_model)
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
check_model_format(updated_model2)
check_no_model_format(updated_model2)


@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
Expand Down Expand Up @@ -290,7 +306,7 @@ def test_delete_model(firebase_model):

# Test tensor flow conversion functions if tensor flow is enabled.
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
#'pip install tensorflow==2.0.0b' for version 2 etc.
#'pip install tensorflow==2.2.0' for version 2.2.0 etc.


def _clean_up_directory(save_dir):
Expand Down Expand Up @@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
_clean_up_directory(parent)



@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
def test_from_keras_model(keras_model):
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
Expand All @@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):

try:
check_model(created_model, {'display_name': model.display_name})
check_model_format(created_model, True)
check_tflite_gcs_format(created_model)
finally:
_clean_up_model(created_model)

Expand All @@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
assert created_model.validation_error is None
finally:
_clean_up_model(created_model)


# Test AutoML functionality if AutoML is enabled.
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
# successful test. (Test is skipped otherwise)

@pytest.fixture
def automl_model():
assert _AUTOML_ENABLED

# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
# the test.
automl_client = automl_v1.AutoMlClient()
project_id = firebase_admin.get_app().project_id
parent = automl_client.location_path(project_id, 'us-central1')
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
# Expecting exactly one. (Ok to use last one if somehow more than 1)
automl_ref = None
for model in models:
automl_ref = model.name

# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
if automl_ref is None:
pytest.skip("No pre-existing AutoML model found. Skipping test")

source = ml.TFLiteAutoMlSource(automl_ref)
tflite_format = ml.TFLiteFormat(model_source=source)
ml_model = ml.Model(
display_name=_random_identifier('TestModel_automl_'),
tags=['test_automl'],
model_format=tflite_format)
model = ml.create_model(model=ml_model)
yield model
_clean_up_model(model)

@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
automl_model.wait_for_unlocked()

check_model(automl_model, {
'display_name': automl_model.display_name,
'tags': ['test_automl'],
})
check_tflite_automl_format(automl_model)
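
The AutoML test above reads `model.model_format.model_source.auto_ml_model` back from the server response, which relies on the new `_init_model_source` dispatch in `TFLiteFormat.from_dict`. The following is a rough sketch of that dispatch using simplified stand-in classes (`GcsSource` and `AutoMlSource` are hypothetical names, not the SDK classes, and they skip validation); only the dictionary keys `gcsTfliteUri` and `automlModel` are taken from the diff.

```python
class GcsSource:
    """Stand-in for ml.TFLiteGCSModelSource (hypothetical, no validation)."""

    def __init__(self, gcs_tflite_uri):
        self.gcs_tflite_uri = gcs_tflite_uri

    def as_dict(self):
        return {'gcsTfliteUri': self.gcs_tflite_uri}


class AutoMlSource:
    """Stand-in for ml.TFLiteAutoMlSource (hypothetical, no validation)."""

    def __init__(self, auto_ml_model):
        self.auto_ml_model = auto_ml_model

    def as_dict(self):
        return {'automlModel': self.auto_ml_model}


def init_model_source(data):
    """Pop the first recognised source key from data, mirroring the diff."""
    gcs_tflite_uri = data.pop('gcsTfliteUri', None)
    if gcs_tflite_uri:
        return GcsSource(gcs_tflite_uri)
    auto_ml_model = data.pop('automlModel', None)
    if auto_ml_model:
        return AutoMlSource(auto_ml_model)
    return None


# Sample payload invented for this sketch.
payload = {
    'automlModel': 'projects/my-project/locations/us-central1/models/ICN12345',
    'sizeBytes': '1024',
}
source = init_model_source(payload)
print(type(source).__name__, payload)
```

Because the helper pops the key it consumes, the caller's dictionary is mutated; in the diff, `from_dict` passes a copy (`data_copy`) for that reason and keeps the remaining fields as the format's `_data`.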
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1
cachecontrol >= 0.12.6
google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy'
google-api-python-client >= 1.7.8
google-auth == 1.18.0 # temporary workaround
google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy'
google-cloud-storage >= 1.18.0