Skip to content

service: records: add MultiFieldsResolver #615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions invenio_records_resources/records/systemfields/entity_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,50 @@ def __get__(self, record, owner=None):
return self.obj(record)


class MultiReferenceEntityField(ReferencedEntityField):
"""System field extending ReferencedEntityField to support object lists."""

def set_obj(self, instance, objs):
"""Set multiple referenced entities."""
references = []

for obj in objs:
if isinstance(obj, dict):
ref_dict = obj
elif isinstance(obj, EntityProxy):
ref_dict = obj.reference_dict
elif obj is not None:
ref_dict = self._registry.reference_entity(obj, raise_=True)
else:
continue

if not self._check_reference(instance, ref_dict):
raise ValueError(f"Invalid reference for '{self.key}': {ref_dict}")

references.append(ref_dict)

self.set_dictkey(instance, references)
self._set_cache(instance, None)

def obj(self, instance):
"""Get the referenced entities as a list of `EntityProxy` objects."""
cached = self._get_cache(instance)
if cached is not None:
return cached

references_list = self.get_dictkey(instance)
if references_list is None:
return []

resolved_objects = [
self._registry.resolve_entity_proxy(ref_dict)
for ref_dict in references_list
]

self._set_cache(instance, resolved_objects)
return resolved_objects


def check_allowed_references(get_allows_none, get_allowed_types, request, ref_dict):
"""Check the reference according to rules specific to requests.

Expand Down
89 changes: 79 additions & 10 deletions invenio_records_resources/services/records/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,7 @@ def _find_fields(self, service, value):
The `id` field used to match the resolved record is hardcoded,
as in the `read_many` method.
"""
fields = []
for field in self._fields:
if field.has(service, value):
fields.append(field)
return fields
return [field for field in self._fields if field.has(service, value)]

def _fetch_referenced(self, grouped_values, identity):
"""Search and fetch referenced recs by ids."""
Expand All @@ -472,11 +468,10 @@ def _add_dereferenced_record(service, value, resolved_rec):
_add_dereferenced_record(service, value, hit)

ghost_values = all_values - found_values
if ghost_values:
for value in ghost_values:
# set dereferenced record to None. That will trigger eventually
# the field.ghost_record() to be called
_add_dereferenced_record(service, value, None)
for value in ghost_values:
# set dereferenced record to None. That will trigger eventually
# the field.ghost_record() to be called
_add_dereferenced_record(service, value, None)

def resolve(self, identity, hits):
"""Collect field values and resolve referenced records."""
Expand Down Expand Up @@ -510,3 +505,77 @@ def expand(self, identity, hit):
dict_merge(results, d)

return results


class MultiFieldsResolver(FieldsResolver):
"""Resolve the reference record for each of the configured fields.

Given a list of fields referencing other records/objects,
it fetches and returns the dereferenced record/obj.

This class supports resolution of nested fields and efficiently batches
resolution requests to services.
"""

def _collect_values(self, hits):
"""Collect all field values to be expanded."""
grouped_values = dict()

for hit in hits:
for field in self._fields:
try:
value = dict_lookup(hit, field.field_name)
if value is None:
continue
except KeyError:
continue

# Ensure `get_value_service` can return multiple (v, service) tuples
values_services = field.get_value_service(value)

if not isinstance(values_services, list):
values_services = [values_services] # Ensure list format

for v, service in values_services:
field.add_service_value(service, v)
grouped_values.setdefault(service, set()).add(v)

return grouped_values

def expand(self, identity, hit):
"""Expand and return the resolved fields for the given hit."""
results = {}

for field in self._fields:
try:
value = dict_lookup(hit, field.field_name)
if value is None:
continue
except KeyError:
continue

# Ensure `get_value_service` supports lists of (value, service)
values_services = field.get_value_service(value)
resolved_recs = {}
if isinstance(values_services, list):
resolved_recs = []
for v, service in values_services:
resolved_rec = field.get_dereferenced_record(service, v)
if resolved_rec:
resolved = field.pick(identity, resolved_rec)
if isinstance(resolved, list):
resolved_recs.extend(resolved)
else:
resolved_recs.append(field.pick(identity, resolved_rec))
else:
v, service = values_services
resolved_rec = field.get_dereferenced_record(service, v)
if resolved_rec:
resolved_recs = field.pick(identity, resolved_rec)
if resolved_recs:
# Maintain nested structure
d = dict()
dict_set(d, field.field_name, resolved_recs)
dict_merge(results, d)

return results
7 changes: 5 additions & 2 deletions invenio_records_resources/services/references/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

"""Service-related things for entity references."""

from .schema import EntityReferenceBaseSchema
from .schema import EntityReferenceBaseSchema, MultipleEntityReferenceBaseSchema

__all__ = ("EntityReferenceBaseSchema",)
__all__ = (
"EntityReferenceBaseSchema",
"MultipleEntityReferenceBaseSchema",
)
27 changes: 27 additions & 0 deletions invenio_records_resources/services/references/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,30 @@ def create_from_dict(cls, allowed_types, special_fields=None):
return cls.from_dict(
{ref_type: field_types[ref_type] for ref_type in allowed_types}
)


class MultipleEntityReferenceBaseSchema(EntityReferenceBaseSchema):
"""Base schema for entity references, allowing multiple keys.

Example of an allowed value: ``{"user": 1, "record": "abcd-1234"}``.
Example of a disallowed value: ``{"user": 1}``.
"""

@classmethod
def create_from_dict(cls, allowed_types, special_fields=None):
"""Create an entity reference schema based on the allowed ref types.

Per default, a ``fields.String()`` field is registered for each of
the type names in the ``allowed_types`` list.
The field type can be customized by providing an entry in the
``special_fields`` dict, with the type name as key and the field type
as value (e.g. ``{"user": fields.Integer()}``).
"""
field_types = special_fields or {}
for ref_type in allowed_types:
# each type would be a String field per default
field_types.setdefault(ref_type, fields.String())

return cls.from_dict(
{ref_type: field_types[ref_type] for ref_type in allowed_types}
)
Loading