Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Add support for exporting all PyText models that use contextual embeddings #468

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytext/config/field_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class CharFeatConfig(ModuleConfig):
class PretrainedModelEmbeddingConfig(ConfigBase):
embed_dim: int = 0
model_paths: Optional[Dict[str, str]] = None
export_input_names: List[str] = ["pretrained_embeds"]
export_input_names: List[str] = ["contextual_token_embedding"]


class FloatVectorConfig(ConfigBase):
Expand Down
1 change: 0 additions & 1 deletion pytext/exporters/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def _add_feature_lengths(cls, input_names: List[str], dummy_model_input: List):
"""If any of the input_names have tokens or seq_tokens, add the length
of those tokens to dummy_input
"""

if "tokens_vals" in input_names:
dummy_model_input.append(
torch.tensor([1, 1], dtype=torch.long)
Expand Down
25 changes: 19 additions & 6 deletions pytext/models/embeddings/pretrained_model_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
from pytext.config.field_config import PretrainedModelEmbeddingConfig

Expand All @@ -17,11 +16,25 @@ def from_config(cls, config: PretrainedModelEmbeddingConfig, *args, **kwargs):
return cls(config.embed_dim)

def forward(self, embedding: torch.Tensor) -> torch.Tensor:
if embedding.shape[1] % self.embedding_dim != 0:
embedding_shape = torch.onnx.operators.shape_as_tensor(embedding)

# Since embeddings vector is flattened, verify its shape correctness.
if embedding_shape[1].item() % self.embedding_dim != 0:
raise ValueError(
f"Input embedding_dim {embedding.shape[1]} is not a"
f"Input embedding_dim {embedding_shape[1]} is not a"
+ f" multiple of specified embedding_dim {self.embedding_dim}"
)
num_tokens = embedding.shape[1] // self.embedding_dim
unflattened_embedding = embedding.view(-1, num_tokens, self.embedding_dim)
return unflattened_embedding

# Unflatten embedding Tensor from (batch_size, seq_len * embedding_size)
# to (batch_size, seq_len, embedding_size).
num_tokens = embedding_shape[1] // self.embedding_dim
new_embedding_shape = torch.cat(
(
torch.LongTensor([-1]),
num_tokens.view(1),
torch.LongTensor([self.embedding_dim]),
)
)
return torch.onnx.operators.reshape_from_tensor_shape(
embedding, new_embedding_shape
)