Commit 5509fa8

borguz authored and facebook-github-bot committed
open source transformer based models - data, tensorizers and tokenizer (facebookresearch#708)
Summary:
Pull Request resolved: facebookresearch#708

Open source BERTTensorizer, XLMTensorizer, WordpieceTokenizer and PackedLMData. Working up to open sourcing MaskedLM and BERT classification models.

Reviewed By: rutyrinott

Differential Revision: D15868227

fbshipit-source-id: 891730437b570d5b152ab2da4ea19c4aaef6bb2b
1 parent 1f047ef commit 5509fa8

File tree

7 files changed: +1695 −4 lines changed


pytext/data/bert_tensorizer.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
from typing import List

from fairseq.data.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import TokenTensorizer
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, Vocabulary, pad_and_tensorize


class BERTTensorizer(TokenTensorizer):
    """
    Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
    """

    __EXPANSIBLE__ = True

    class Config(TokenTensorizer.Config):
        #: The tokenizer to use to split input text into tokens.
        columns: List[str] = ["text"]
        tokenizer: Tokenizer.Config = WordPieceTokenizer.Config()
        add_bos_token: bool = False
        add_eos_token: bool = True
        bos_token: str = "[CLS]"
        eos_token: str = "[SEP]"
        pad_token: str = "[PAD]"
        unk_token: str = "[UNK]"
        mask_token: str = "[MASK]"
        vocab_file: str = WordPieceTokenizer.Config().wordpiece_vocab_path

    @classmethod
    def from_config(cls, config: Config, **kwargs):
        tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
        replacements = {
            config.unk_token: UNK,
            config.pad_token: PAD,
            config.bos_token: BOS,
            config.eos_token: EOS,
            config.mask_token: MASK,
        }
        if isinstance(tokenizer, WordPieceTokenizer):
            vocab = Vocabulary(
                [token for token, _ in tokenizer.vocab.items()],
                replacements=replacements,
            )
        else:
            dictionary = BertDictionary.load(config.vocab_file)
            vocab = Vocabulary(
                dictionary.symbols, dictionary.count, replacements=replacements
            )
        return cls(
            columns=config.columns,
            tokenizer=tokenizer,
            add_bos_token=config.add_bos_token,
            add_eos_token=config.add_eos_token,
            use_eos_token_for_bos=config.use_eos_token_for_bos,
            max_seq_len=config.max_seq_len,
            vocab=vocab,
            **kwargs,
        )

    def __init__(self, columns, **kwargs):
        super().__init__(text_column=None, **kwargs)
        self.columns = columns
        # Manually initialize column_schema since we are sending None to TokenTensorizer
        self.column_schema = [(column, str) for column in columns]

    def numberize(self, row):
        """Tokenize, look up in vocabulary."""
        sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
        sentences[0] = [self.vocab.idx[BOS]] + sentences[0]
        seq_lens = (len(sentence) for sentence in sentences)
        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
        tokens = list(itertools.chain(*sentences))
        segment_labels = list(itertools.chain(*segment_labels))
        seq_len = len(tokens)
        # tokens, segment_label, seq_len
        return tokens, segment_labels, seq_len

    def sort_key(self, row):
        return row[2]

    def tensorize(self, batch):
        tokens, segment_labels, seq_lens = zip(*batch)
        tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
        pad_mask = (tokens != self.vocab.get_pad_index()).long()
        segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
        return tokens, pad_mask, segment_labels
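
To make the numberize/tensorize flow above concrete, here is a minimal standalone sketch, not part of this commit or the PyText API, that mimics BERTTensorizer on a sentence pair. The names toy_vocab, lookup, numberize, tensorize, and the padding helper are illustrative stand-ins for the WordPiece tokenizer, _lookup_tokens, and pad_and_tensorize.

# Toy illustration (hypothetical helpers, not the PyText API): mimic the
# BERTTensorizer numberize/tensorize flow for a sentence pair.
import itertools
import torch

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "hello": 4, "world": 5, "good": 6, "morning": 7}

def lookup(text):
    # stand-in for _lookup_tokens(): tokenize and append [SEP] per sentence
    return [toy_vocab.get(t, toy_vocab["[UNK]"]) for t in text.split()] + [toy_vocab["[SEP]"]]

def numberize(row, columns=("text1", "text2")):
    sentences = [lookup(row[c]) for c in columns]
    sentences[0] = [toy_vocab["[CLS]"]] + sentences[0]  # BOS only on the first sentence
    seq_lens = (len(s) for s in sentences)
    segment_labels = ([i] * n for i, n in enumerate(seq_lens))
    tokens = list(itertools.chain(*sentences))
    return tokens, list(itertools.chain(*segment_labels)), len(tokens)

def tensorize(batch, pad_idx=0):
    tokens, segment_labels, _ = zip(*batch)
    max_len = max(len(t) for t in tokens)

    def pad(seqs):
        # stand-in for pad_and_tensorize(): right-pad every sequence to max_len
        return torch.tensor([s + [pad_idx] * (max_len - len(s)) for s in seqs])

    padded = pad(tokens)
    return padded, (padded != pad_idx).long(), pad(segment_labels)

rows = [{"text1": "hello world", "text2": "good morning"},
        {"text1": "hello", "text2": "world"}]
batch = [numberize(r) for r in rows]
tokens, pad_mask, segments = tensorize(batch)
print(tokens)    # [CLS] tok ... [SEP] tok ... [SEP], right-padded per batch
print(segments)  # 0 over the first sentence span, 1 over the second

Note that in the real class the default Config sets add_eos_token=True, so the inherited token lookup appends [SEP] to every sentence, while [CLS] is prepended only once in numberize; the sketch reproduces that layout.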

pytext/data/packed_lm_data.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List, Optional, Type

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.data import RowData
from pytext.data.sources import DataSource
from pytext.data.tensorizers import Tensorizer, TokenTensorizer
from pytext.data.xlm_tensorizer import XLMTensorizer


class PackedLMData(Data):
    """
    Special purpose Data object which assumes a single text tensorizer. Packs
    tokens into a square batch with no padding. Used for LM training. The object
    also takes in an optional language argument which is used for cross-lingual
    LM training.
    """

    __EXPANSIBLE__ = True

    class Config(Data.Config):
        max_seq_len: int = 128

    @classmethod
    def from_config(
        cls,
        config: Config,
        schema: Dict[str, Type],
        tensorizers: Dict[str, Tensorizer],
        language: Optional[str] = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        return super(PackedLMData, cls).from_config(
            config,
            schema,
            tensorizers,
            rank,
            world_size,
            language=language,
            max_seq_len=config.max_seq_len,
        )

    def __init__(
        self,
        data_source: DataSource,
        tensorizers: Dict[str, Tensorizer],
        batcher: Batcher = None,
        max_seq_len: int = Config.max_seq_len,
        sort_key: Optional[str] = None,
        # language is used in cross-lingual LM training
        language: Optional[str] = None,
        in_memory: Optional[bool] = False,
    ):
        super().__init__(data_source, tensorizers, batcher, sort_key, in_memory)
        assert len(list(self.tensorizers.items())) == 1
        self.tensorizer_name, self.tensorizer = list(self.tensorizers.items())[0]
        self.remainder: Dict[str, List[int]] = {"tokens": [], "segment_labels": []}
        self.max_seq_len = max_seq_len
        self.language = language
        self.batch = {Stage.TRAIN: None, Stage.EVAL: None, Stage.TEST: None}

    def _parse_row(self, row):
        """
        The output of numberization has a different number of elements depending
        on the tensorizer used. For example: the positions tensor is only output
        by the XLMTensorizer. This function unpacks the elements according to the
        specific tensorizer used.
        Additionally, since we are packing tokens into fixed size
        blocks, we don't need to use the positions vector output by the call to
        numberize. We will simply create this in `_format_output_row`.
        """
        numberized_row = self.tensorizer.numberize(row)
        if isinstance(self.tensorizer, XLMTensorizer):
            tokens, seq_len, segment_labels, _ = numberized_row
        elif isinstance(self.tensorizer, BERTTensorizer):
            tokens, segment_labels, seq_len = numberized_row
        elif isinstance(self.tensorizer, TokenTensorizer):
            tokens, seq_len, _ = numberized_row
            segment_labels = []
        else:
            raise NotImplementedError(
                "PackedLMData only supports XLMTensorizer, BERTTensorizer and "
                "TokenTensorizer."
            )
        return tokens, segment_labels, seq_len

    def _format_output_row(self, tokens, segment_labels, seq_len):
        """
        The tensorize function for different tensorizers takes in a different
        number of inputs which may be arranged differently. This function formats
        the output dict to conform to the expectations of the tensorizer.
        In case of the XLMTensorizer, we also need to create a new positions list
        which goes from 0 to seq_len.
        """
        if isinstance(self.tensorizer, XLMTensorizer):
            positions = [index for index in range(seq_len)]
            return {self.tensorizer_name: (tokens, seq_len, segment_labels, positions)}
        elif isinstance(self.tensorizer, BERTTensorizer):
            return {self.tensorizer_name: (tokens, segment_labels, seq_len)}
        elif isinstance(self.tensorizer, TokenTensorizer):
            # dummy token_ranges
            return {self.tensorizer_name: (tokens, seq_len, [(-1, -1)] * seq_len)}
        else:
            raise NotImplementedError(
                "PackedLMData only supports BERTTensorizer and TokenTensorizer."
            )

    def _yield_and_reset(self):
        packed_tokens = list(self.remainder["tokens"])
        packed_segments = list(self.remainder["segment_labels"])
        self.remainder: Dict[str, List[int]] = {"tokens": [], "segment_labels": []}
        return RowData(
            {},  # packed LM data doesn't respect data cardinality
            self._format_output_row(packed_tokens, packed_segments, len(packed_tokens)),
        )

    def numberize_rows(self, rows):
        """
        This function does the actual packing. It processes rows until we obtain
        a block of data with length = max_seq_len.
        """
        for row in rows:

            # if the packedLM object has a language member then a cross-lingual
            # LM is being trained using monolingual data.
            # Add this language to the row since the underlying
            # tensorizer needs this to generate language embeddings (used as
            # segment_labels below)
            if self.language:
                row["language"] = self.language

            tokens, segment_labels, seq_len = self._parse_row(row)
            remaining = self.max_seq_len - len(self.remainder["tokens"]) - 1
            while remaining < len(tokens):
                self.remainder["tokens"].extend(tokens[:remaining])
                self.remainder["segment_labels"].extend(segment_labels[:remaining])
                tokens = tokens[remaining:]
                segment_labels = segment_labels[remaining:]
                yield self._yield_and_reset()
                remaining = self.max_seq_len - 1
            self.remainder["tokens"].extend(tokens)
            self.remainder["segment_labels"].extend(segment_labels)
        if len(self.remainder["tokens"]):
            yield self._yield_and_reset()
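
The packing loop in numberize_rows is the heart of this class, so here is a self-contained sketch that reproduces just that loop on plain token lists. pack_token_streams and the toy input rows are hypothetical and made up for illustration (not the PyText API); the max_seq_len - 1 budget per emitted block simply mirrors the original loop.

# Hypothetical sketch of the packing loop in numberize_rows: token streams from
# consecutive rows are concatenated and emitted as fixed-size blocks, so batches
# built from the blocks need no padding.
from typing import Iterable, Iterator, List

def pack_token_streams(rows: Iterable[List[int]], max_seq_len: int = 8) -> Iterator[List[int]]:
    remainder: List[int] = []
    for tokens in rows:
        remaining = max_seq_len - len(remainder) - 1
        while remaining < len(tokens):
            remainder.extend(tokens[:remaining])  # fill the current block
            tokens = tokens[remaining:]           # carry the overflow forward
            yield remainder
            remainder = []
            remaining = max_seq_len - 1
        remainder.extend(tokens)
    if remainder:                                 # flush the final partial block
        yield remainder

rows = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12], [13, 14]]
print(list(pack_token_streams(rows, max_seq_len=8)))
# [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]

Because every emitted block (except possibly the final flush) has exactly max_seq_len - 1 tokens, the resulting batches are uniform in length, which is what the class docstring means by packing tokens into a square batch with no padding.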
