-
Notifications
You must be signed in to change notification settings - Fork 339
Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
573e7cb
e67bceb
1f018fe
dfe0a37
7704c44
8381ac5
a2e7544
cadd6c6
b02ea22
a0a2411
fc63db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
deleting, publishing and unpublishing Firebase ML Kit models. | ||
""" | ||
|
||
import datetime | ||
import numbers | ||
import re | ||
import requests | ||
import six | ||
|
@@ -28,6 +30,12 @@ | |
|
||
_MLKIT_ATTRIBUTE = '_mlkit'
_MAX_PAGE_SIZE = 100

# The patterns below are anchored with \Z rather than $ because $ also matches
# just before a trailing newline, which would let values such as 'abc\n' pass
# validation.
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}\Z')
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}\Z')
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}\Z')
# Deliberately loose: only checks for a gs:// scheme, a plausible bucket name
# and a non-empty object path.
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
# Splits a model resource name into its project ID and model ID components.
_RESOURCE_NAME_PATTERN = re.compile(
    r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})\Z')
|
||
|
||
def _get_mlkit_service(app): | ||
|
@@ -47,7 +55,7 @@ def _get_mlkit_service(app): | |
|
||
def get_model(model_id, app=None):
    """Fetches a model through the ML Kit service and wraps it in a Model.

    Args:
        model_id: The ID passed on to the service's ``get_model`` call.
        app: Passed to ``_get_mlkit_service`` to select the service instance
            (defaults to None).

    Returns:
        Model: Built from the service response.
    """
    mlkit_service = _get_mlkit_service(app)
    # Model takes the response fields as keyword arguments, so the response is
    # unpacked; passing it positionally would bind it to display_name.
    return Model(**mlkit_service.get_model(model_id))
|
||
|
||
def list_models(list_filter=None, page_size=None, page_token=None, app=None): | ||
|
@@ -62,29 +70,215 @@ def delete_model(model_id, app=None): | |
|
||
|
||
class Model(object):
    """A Firebase ML Kit Model object.

    Args:
        display_name: The display name of your model - used to identify your model in code.
        tags: Optional list of strings associated with your model. Can be used in list queries.
        model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
        kwargs: A set of keywords returned by an API response.
    """
    def __init__(self, display_name=None, tags=None, model_format=None, **kwargs):
        self._data = kwargs
        self._model_format = None
        # The format is kept as a typed object rather than as raw data, so it is
        # lifted out of the keyword dictionary here.
        tflite_format = self._data.pop('tfliteModel', None)
        if tflite_format:
            self._model_format = TFLiteFormat(**tflite_format)
        # Explicit arguments go through the validating setters and override
        # anything derived from kwargs.
        if display_name is not None:
            self.display_name = display_name
        if tags is not None:
            self.tags = tags
        if model_format is not None:
            self.model_format = model_format

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # pylint: disable=protected-access
            return self._data == other._data and self._model_format == other._model_format
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def _get_timestamp(self, key):
        """Converts the 'seconds' entry stored under ``key`` into a datetime.

        Returns None when the entry is missing, is not a dictionary, or does
        not hold a numeric 'seconds' value. The result comes from
        ``datetime.datetime.fromtimestamp`` without a timezone argument, so it
        is a naive datetime in local time.
        """
        timestamp = self._data.get(key)
        if not isinstance(timestamp, dict):
            return None
        seconds = timestamp.get('seconds')
        if not isinstance(seconds, numbers.Number):
            return None
        return datetime.datetime.fromtimestamp(float(seconds))

    @property
    def model_id(self):
        """The model ID parsed from the resource name, or None if no name is set.

        Raises:
            ValueError: If a name is present but is not a valid resource name.
        """
        name = self._data.get('name')
        if not name:
            return None
        _, model_id = _validate_and_parse_name(name)
        return model_id

    @property
    def display_name(self):
        """The display name, or None if it has not been set."""
        return self._data.get('displayName')

    @display_name.setter
    def display_name(self, display_name):
        self._data['displayName'] = _validate_display_name(display_name)

    @property
    def create_time(self):
        """Returns the creation timestamp"""
        return self._get_timestamp('createTime')

    @property
    def update_time(self):
        """Returns the last update timestamp"""
        return self._get_timestamp('updateTime')

    @property
    def validation_error(self):
        """The validation error message stored in the model state, or None."""
        return self._data.get('state', {}).get('validationError', {}).get('message')

    @property
    def published(self):
        """True if the model state has a truthy 'published' entry."""
        return bool(self._data.get('state', {}).get('published'))

    @property
    def etag(self):
        """The 'etag' value, or None if absent."""
        return self._data.get('etag')

    @property
    def model_hash(self):
        """The 'modelHash' value, or None if absent."""
        return self._data.get('modelHash')

    @property
    def tags(self):
        """The tags, or None if they have not been set."""
        return self._data.get('tags')

    @tags.setter
    def tags(self, tags):
        self._data['tags'] = _validate_tags(tags)

    @property
    def locked(self):
        """True if the model has at least one entry under 'activeOperations'."""
        return bool(self._data.get('activeOperations'))

    @property
    def model_format(self):
        """The ModelFormat attached to this model, or None."""
        return self._model_format

    @model_format.setter
    def model_format(self, model_format):
        if model_format is not None:
            _validate_model_format(model_format)
        self._model_format = model_format  # Can be None

    def as_dict(self):
        """Returns a shallow copy of the model data merged with the format's dictionary."""
        result = dict(self._data)
        if self._model_format is not None:
            result.update(self._model_format.as_dict())
        return result
|
||
|
||
class ModelFormat(object):
    """Abstract base class representing a Model Format such as TFLite."""

    def as_dict(self):
        """Returns the dictionary form of the format; subclasses must override this."""
        raise NotImplementedError()
|
||
|
||
class TFLiteFormat(ModelFormat):
    """Model format representing a TFLite model.

    Args:
        model_source: A TFLiteModelSource sub class. Specifies the details of the model source.
        kwargs: A set of keywords returned by an API response
    """
    def __init__(self, model_source=None, **kwargs):
        self._data = kwargs
        # A GCS URI found in the keyword data becomes a typed model source.
        uri = self._data.pop('gcsTfliteUri', None)
        self._model_source = TFLiteGCSModelSource(gcs_tflite_uri=uri) if uri else None
        # An explicitly supplied source is validated by the setter and wins.
        if model_source is not None:
            self.model_source = model_source

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # pylint: disable=protected-access
        return self._data == other._data and self._model_source == other._model_source

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def model_source(self):
        """The TFLiteModelSource attached to this format, or None."""
        return self._model_source

    @model_source.setter
    def model_source(self, model_source):
        if model_source is not None and not isinstance(model_source, TFLiteModelSource):
            raise TypeError('Model source must be a TFLiteModelSource object.')
        self._model_source = model_source  # Can be None

    @property
    def size_bytes(self):
        """The 'sizeBytes' value, or None if absent."""
        return self._data.get('sizeBytes')

    def as_dict(self):
        """Returns the format data, merged with the source's dictionary, under 'tfliteModel'."""
        contents = dict(self._data)
        if self._model_source:
            contents.update(self._model_source.as_dict())
        return {'tfliteModel': contents}
|
||
|
||
class TFLiteModelSource(object):
    """Abstract base class representing a model source for TFLite format models."""

    def as_dict(self):
        """Returns the dictionary form of the source; subclasses must override this."""
        raise NotImplementedError()
|
||
|
||
class TFLiteGCSModelSource(TFLiteModelSource):
    """TFLite model source representing a tflite model file stored in GCS."""

    def __init__(self, gcs_tflite_uri):
        # Goes through the property setter, which validates the URI.
        self.gcs_tflite_uri = gcs_tflite_uri

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # pylint: disable=protected-access
        return self._gcs_tflite_uri == other._gcs_tflite_uri

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def gcs_tflite_uri(self):
        """The GCS URI of the tflite model file."""
        return self._gcs_tflite_uri

    @gcs_tflite_uri.setter
    def gcs_tflite_uri(self, gcs_tflite_uri):
        self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

    def as_dict(self):
        """Returns the URI in a dictionary keyed by 'gcsTfliteUri'."""
        return {"gcsTfliteUri": self._gcs_tflite_uri}
|
||
#TODO(ifielker): implement from_saved_model etc. | ||
|
||
|
||
class ListModelsPage(object): | ||
|
@@ -105,7 +299,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): | |
@property
def models(self):
    """A list of Models from this page."""
    # Each entry is unpacked because Model takes its fields as keyword
    # arguments; a positional dict would bind to display_name.
    return [Model(**model) for model in self._list_response.get('models', [])]
|
||
@property | ||
def list_filter(self): | ||
|
@@ -179,13 +373,48 @@ def __iter__(self): | |
return self | ||
|
||
|
||
def _validate_and_parse_name(name):
    """Splits a model resource name into a (project_id, model_id) tuple.

    Raises:
        ValueError: If the name does not match the resource name pattern.
    """
    # Resource names normally arrive in API call responses, so a mismatch here
    # points at a model that was built by hand from a malformed dictionary.
    match = _RESOURCE_NAME_PATTERN.match(name)
    if match is None:
        raise ValueError('Model resource name format is invalid.')
    return match.group('project_id', 'model_id')
|
||
|
||
def _validate_model_id(model_id):
    """Checks that a model ID is a string in the accepted format.

    Raises:
        TypeError: If the model ID is not a string.
        ValueError: If the model ID does not match the model ID pattern.
    """
    if not isinstance(model_id, six.string_types):
        raise TypeError('Model ID must be a string.')
    # The compiled module-level pattern is the single definition of the ID
    # rule; the inline duplicate of the same expression is no longer needed.
    if not _MODEL_ID_PATTERN.match(model_id):
        raise ValueError('Model ID format is invalid.')
|
||
|
||
def _validate_display_name(display_name):
    """Checks that a display name is a string in the accepted format.

    Returns:
        The display name, unchanged.

    Raises:
        TypeError: If the display name is not a string.
        ValueError: If the display name does not match the display name pattern.
    """
    # Checked explicitly, as the model ID and tag validators do, so that a
    # non-string produces a clear message instead of the regex engine's error.
    if not isinstance(display_name, six.string_types):
        raise TypeError('Display name must be a string.')
    if not _DISPLAY_NAME_PATTERN.match(display_name):
        raise ValueError('Display name format is invalid.')
    return display_name
|
||
|
||
def _validate_tags(tags):
    """Checks that tags is a list of strings that each match the tag pattern.

    Returns:
        The tags list, unchanged.

    Raises:
        TypeError: If tags is not a list, or holds a non-string element.
        ValueError: If any tag does not match the tag pattern.
    """
    is_string_list = isinstance(tags, list) and all(
        isinstance(tag, six.string_types) for tag in tags)
    if not is_string_list:
        raise TypeError('Tags must be a list of strings.')
    for tag in tags:
        if not _TAG_PATTERN.match(tag):
            raise ValueError('Tag format is invalid.')
    return tags
|
||
|
||
def _validate_gcs_tflite_uri(uri):
    """Checks that a value looks like a GCS URI pointing at a model file.

    Returns:
        The URI, unchanged.

    Raises:
        TypeError: If the URI is not a string.
        ValueError: If the URI does not match the GCS URI pattern.
    """
    # Checked explicitly, as the model ID and tag validators do, so that a
    # non-string produces a clear message instead of the regex engine's error.
    if not isinstance(uri, six.string_types):
        raise TypeError('GCS TFLite URI must be a string.')
    # GCS Bucket naming rules are complex. The regex is not comprehensive.
    # See https://cloud.google.com/storage/docs/naming for full details.
    if not _GCS_TFLITE_URI_PATTERN.match(uri):
        raise ValueError('GCS TFLite URI format is invalid.')
    return uri
|
||
def _validate_model_format(model_format):
    """Returns the given model format after checking that it is a ModelFormat.

    Raises:
        TypeError: If the value is not a ModelFormat instance.
    """
    if isinstance(model_format, ModelFormat):
        return model_format
    raise TypeError('Model format must be a ModelFormat object.')
|
||
def _validate_list_filter(list_filter): | ||
if list_filter is not None: | ||
if not isinstance(list_filter, six.string_types): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the "String -" part from the display_name description in the Model docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done