This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Make dict_embedding Torchscript friendly #1240

Closed
27 changes: 15 additions & 12 deletions pytext/models/embeddings/dict_embedding.py
@@ -10,12 +10,11 @@
 from pytext.data.tensorizers import Tensorizer
 from pytext.data.utils import PAD_INDEX, UNK_INDEX, Vocabulary
 from pytext.fields import FieldMeta
-from pytext.utils import cuda

 from .embedding_base import EmbeddingBase


-class DictEmbedding(EmbeddingBase, nn.Embedding):
+class DictEmbedding(EmbeddingBase):
     """
     Module for dictionary feature embeddings for tokens. Dictionary features are
     also known as gazetteer features. These are per token discrete features that
@@ -102,13 +101,15 @@ def __init__(
         unk_index: int = UNK_INDEX,
         mobile: bool = False,
     ) -> None:
-        self.pad_index = pad_index
+        super().__init__(embed_dim)
         self.unk_index = unk_index
-        EmbeddingBase.__init__(self, embed_dim)
-        nn.Embedding.__init__(
-            self, num_embeddings, embed_dim, padding_idx=self.pad_index
+        self.pad_index = pad_index
+        self.embedding = nn.Embedding(
+            num_embeddings, embed_dim, padding_idx=self.pad_index
         )
-        self.pooling_type = pooling_type
+        # Temporary workaround till https://github.com/pytorch/pytorch/issues/32840
+        # is resolved
+        self.pooling_type = str(pooling_type)
         self.mobile = mobile

     def find_and_replace(
@@ -124,7 +125,7 @@ def find_and_replace(
         else:
             return torch.where(
                 tensor == find_val,
-                cuda.GetTensor(torch.full_like(tensor, replace_val)),
+                torch.full_like(tensor, replace_val, device=tensor.device),
                 tensor,
             )

@@ -157,23 +158,25 @@ def forward(
         # convert all unk indices to pad indices
         feats = self.find_and_replace(feats, self.unk_index, self.pad_index)

-        dict_emb = super().forward(feats)
+        dict_emb = self.embedding(feats)

         # Calculate weighted average of the embeddings
         weighted_embds = dict_emb * weights.unsqueeze(2)
         new_emb_shape = torch.cat(
             (
                 batch_size.view(1),
                 max_toks.view(1),
-                torch.LongTensor([-1]),
-                torch.LongTensor([weighted_embds.size()[-1]]),
+                torch.tensor([-1]).long(),
+                torch.tensor([weighted_embds.size()[-1]]).long(),
             )
         )
         weighted_embds = torch.onnx.operators.reshape_from_tensor_shape(
             weighted_embds, new_emb_shape
         )

-        if self.pooling_type == PoolingType.MEAN:
+        # Temporary workaround till https://github.com/pytorch/pytorch/issues/32840
+        # is resolved
+        if self.pooling_type == "mean":
             reduced_embeds = (
                 torch.sum(weighted_embds, dim=2) / lengths.unsqueeze(2).float()
             )
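Taken together, the diff above replaces multiple inheritance from nn.Embedding with a held submodule, compares the pooling mode as a plain string (enum comparison is the open TorchScript issue pytorch/pytorch#32840 referenced in the added comments), and builds shape tensors with torch.tensor rather than torch.LongTensor. The following is a minimal standalone sketch of that pattern, not taken from this PR and with all names hypothetical, that torch.jit.script can compile:

import torch
import torch.nn as nn


class ToyDictEmbedding(nn.Module):
    """Illustrative only: mirrors the composition + string-flag pattern above."""

    def __init__(self, num_embeddings: int, embed_dim: int, pad_index: int = 1) -> None:
        super().__init__()
        self.pad_index = pad_index
        # Hold nn.Embedding as a submodule rather than inheriting from it.
        self.embedding = nn.Embedding(num_embeddings, embed_dim, padding_idx=pad_index)
        # Store the pooling mode as a plain string; scripting enum comparisons
        # is the open issue the diff's comments point to (pytorch/pytorch#32840).
        self.pooling_type = "mean"

    def forward(self, feats: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # (batch, tokens, dim) embeddings weighted per token, device-agnostic.
        emb = self.embedding(feats) * weights.unsqueeze(-1)
        if self.pooling_type == "mean":
            return emb.mean(dim=1)
        return torch.max(emb, dim=1)[0]


scripted = torch.jit.script(ToyDictEmbedding(num_embeddings=100, embed_dim=8))
out = scripted(torch.randint(0, 100, (2, 5)), torch.rand(2, 5))
print(out.shape)  # torch.Size([2, 8])

The sketch only mirrors the structural pattern; the real DictEmbedding additionally handles lengths, ONNX-friendly reshaping, and a mobile code path, as shown in the diff.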
4 changes: 2 additions & 2 deletions pytext/models/test/dict_embedding_test.py
@@ -17,8 +17,8 @@ def test_basic(self):
             embed_dim=output_dim,
             pooling_type=PoolingType.MEAN,
         )
-        self.assertEqual(embedding_module.weight.size(0), num_embeddings)
-        self.assertEqual(embedding_module.weight.size(1), output_dim)
+        self.assertEqual(embedding_module.embedding.weight.size(0), num_embeddings)
+        self.assertEqual(embedding_module.embedding.weight.size(1), output_dim)

         # The first and last tokens should be mapped to the zero vector.
         # This is due to the invariant that both unk and pad are considered
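With the table now held on a submodule, callers reach the weights through .embedding.weight, as the updated test asserts. A hedged usage sketch, with import paths and keyword names assumed rather than taken from this PR:

import torch
from pytext.config.module_config import PoolingType  # assumed location of PoolingType
from pytext.models.embeddings.dict_embedding import DictEmbedding

module = DictEmbedding(num_embeddings=10, embed_dim=4, pooling_type=PoolingType.MEAN)
# After this change the embedding table lives on the submodule, not on the module itself.
assert module.embedding.weight.size() == (10, 4)
# The point of the PR: the module should now compile under torch.jit.script.
scripted = torch.jit.script(module)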