This repository was archived by the owner on Nov 22, 2022. It is now read-only.

changes to make tutorial code simpler #1002

Closed
90 changes: 90 additions & 0 deletions pytext/data/xlm_constants.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

LANG2ID_15 = {
"ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14,
}

LANG2ID_20 = {
"ar": 0,
"bn": 1,
"de": 2,
"en": 3,
"es": 4,
"fr": 5,
"hi": 6,
"id": 7,
"it": 8,
"ko": 9,
"my": 10,
"pl": 11,
"pt": 12,
"ru": 13,
"sw": 14,
"th": 15,
"tl": 16,
"tr": 17,
"vi": 18,
"zh": 19,
}


LANG2ID_43 = {
"ar": 0,
"bg": 1,
"bn": 2,
"da": 3,
"de": 4,
"el": 5,
"en": 6,
"es": 7,
"fa": 8,
"fr": 9,
"he": 10,
"hi": 11,
"hu": 12,
"id": 13,
"it": 14,
"ja": 15,
"km": 16,
"kn": 17,
"ko": 18,
"lt": 19,
"ml": 20,
"mr": 21,
"ms": 22,
"my": 23,
"nl": 24,
"pa": 25,
"pl": 26,
"ps": 27,
"pt": 28,
"ro": 29,
"ru": 30,
"si": 31,
"sv": 32,
"sw": 33,
"ta": 34,
"te": 35,
"th": 36,
"tl": 37,
"tr": 38,
"ur": 39,
"vi": 40,
"CN": 41,
"TW": 42,
}
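
For context, here is a minimal sketch (not part of the diff) of how these mappings are consumed: the integer ID selects the language-embedding row that an XLM-style model adds to its token and position embeddings. The import assumes this PR is applied and pytext is installed.

```python
from pytext.data.xlm_constants import LANG2ID_15

def language_id(lang: str) -> int:
    """Map an ISO language code to its XLM language-embedding index."""
    return LANG2ID_15[lang]

assert language_id("en") == 4
assert len(LANG2ID_15) == 15  # one embedding row per supported language
```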
70 changes: 27 additions & 43 deletions pytext/data/xlm_tensorizer.py
@@ -11,26 +11,33 @@
from pytext.data.tensorizers import lookup_tokens
from pytext.data.tokenizers import Tokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, Vocabulary, pad_and_tensorize
from pytext.data.xlm_constants import LANG2ID_15
from pytext.data.xlm_dictionary import Dictionary as XLMDictionary


DEFAULT_LANG2ID_DICT = {
"ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14,
}
def read_vocab(
vocab_file: str, max_vocab: int, min_count: int
) -> Tuple[List, List, Dict]:
dictionary = XLMDictionary.read_vocab(vocab_file)
if max_vocab >= 1:
dictionary.max_vocab(max_vocab)
if min_count >= 0:
dictionary.min_count(min_count)
vocab_list = [dictionary.id2word[w] for w in sorted(dictionary.id2word)]
counts = [dictionary.counts[w] for w in vocab_list]
replacements = {"<unk>": UNK, "<pad>": PAD, "<s>": BOS, "</s>": EOS}
return vocab_list, counts, replacements


def read_fairseq_vocab(
vocab_file: str, max_vocab: int = -1, min_count: int = -1
) -> Tuple[List, List, Dict]:
dictionary = MaskedLMDictionary.load(vocab_file)
dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
vocab_list = dictionary.symbols
counts = dictionary.count
replacements = {"<pad>": PAD, "</s>": EOS, "<unk>": UNK, "<mask>": MASK}
return vocab_list, counts, replacements


class XLMTensorizer(BERTTensorizer):
@@ -48,7 +55,7 @@ class Config(BERTTensorizer.Config):
max_vocab: int = 95000
min_count: int = 0
language_columns: List[str] = ["language"]
lang2id: Dict[str, int] = DEFAULT_LANG2ID_DICT
lang2id: Dict[str, int] = LANG2ID_15
reset_positions: bool = False
has_language_in_data: bool = False
use_language_embeddings: bool = True
@@ -235,29 +242,6 @@ def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
positions,
)

def _read_vocab(
self, vocab_file: str, max_vocab: int, min_count: int
) -> Tuple[List, List, Dict]:
dictionary = XLMDictionary.read_vocab(vocab_file)
if max_vocab >= 1:
dictionary.max_vocab(max_vocab)
if min_count >= 0:
dictionary.min_count(min_count)
vocab_list = [dictionary.id2word[w] for w in sorted(dictionary.id2word)]
counts = [dictionary.counts[w] for w in vocab_list]
replacements = {"<unk>": UNK, "<pad>": PAD, "<s>": BOS, "</s>": EOS}
return vocab_list, counts, replacements

def _read_fairseq_vocab(
self, vocab_file: str, max_vocab: int = -1, min_count: int = -1
) -> Tuple[List, List, Dict]:
dictionary = MaskedLMDictionary.load(vocab_file)
dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
vocab_list = dictionary.symbols
counts = dictionary.count
replacements = {"<pad>": PAD, "</s>": EOS, "<unk>": UNK, "<mask>": MASK}
return vocab_list, counts, replacements

def _build_vocab(
self, vocab_file: str, max_vocab: int, min_count: int
) -> Vocabulary:
@@ -266,11 +250,11 @@ def _build_vocab(
source.
"""
if self.is_fairseq:
vocab_list, counts, replacements = self._read_fairseq_vocab(
vocab_list, counts, replacements = read_fairseq_vocab(
vocab_file, max_vocab, min_count
)
else:
vocab_list, counts, replacements = self._read_vocab(
vocab_list, counts, replacements = read_vocab(
vocab_file, max_vocab, min_count
)
return Vocabulary(vocab_list, counts, replacements=replacements)
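
One effect of hoisting the vocab readers out of the class is that they can now be reused without constructing an XLMTensorizer. A hedged sketch under that assumption (the vocab file path is hypothetical; max_vocab and min_count mirror the Config defaults above):

```python
from pytext.data.utils import Vocabulary
from pytext.data.xlm_tensorizer import read_vocab

# "xnli_vocab.txt" is a hypothetical XLM-format vocab file.
vocab_list, counts, replacements = read_vocab(
    "xnli_vocab.txt", max_vocab=95000, min_count=0
)
vocab = Vocabulary(vocab_list, counts, replacements=replacements)
```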
12 changes: 10 additions & 2 deletions pytext/models/masked_lm.py
@@ -138,7 +138,11 @@ def arrange_targets(self, tensor_dict):

def forward(self, *inputs) -> List[torch.Tensor]:
encoded_layers, _ = self.encoder(inputs)
return self.decoder(encoded_layers[-1][self.mask.bool(), :])
encoded_layer = encoded_layers[-1][self.mask.bool(), :]
if encoded_layer.nelement() == 0:
# No masked tokens; fall back to the first state of the first batch
encoded_layer = encoded_layers[-1][0, :1]
return self.decoder(encoded_layer)

def _select_tokens_to_mask(
self, tokens: torch.Tensor, mask_prob: float
@@ -179,4 +183,8 @@ def _mask_input(self, tokens, mask, replacement):
return tokens * (1 - mask) + replacement * mask

def _mask_output(self, tokens, mask):
return torch.masked_select(tokens, mask.bool())
output = torch.masked_select(tokens, mask.bool())
if output.nelement() == 0:
# No masked tokens; fall back to the first target of the first batch
output = tokens[0, :1]
return output
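
A self-contained sketch (plain PyTorch, independent of pytext) of the edge case both guards handle: with an all-zero mask, torch.masked_select returns an empty tensor, so the decoder would receive a zero-element batch; falling back to the first element keeps shapes valid.

```python
import torch

tokens = torch.tensor([[5, 6, 7], [8, 9, 10]])
mask = torch.zeros_like(tokens)  # no tokens were masked in this batch

output = torch.masked_select(tokens, mask.bool())
assert output.nelement() == 0  # an empty selection would break the decoder

if output.nelement() == 0:
    output = tokens[0, :1]  # fall back to the first target of the first batch
print(output)  # tensor([5])
```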