This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 06f528c

Kartikay Khandelwal authored and facebook-github-bot committed
Consolidate BERT, XLM and RoBERTa Tensorizers
Summary: In this diff I take a first stab at consolidating the XLM, BERT and RoBERTa tensorizers. I remove a bunch of dead code and simplify the logic considerably.

- I create a BERTTensorizerBase class that derives from Tensorizer rather than TokenTensorizer. This makes the logic much simpler, especially since we no longer have to deal with all the bos/eos flags. Given that tokenize and lookup_tokens are not part of TokenTensorizer, this formulation makes a lot of sense.
- As per suggestions, I derive the config classes from Tensorizer as well and remove all of the special-token flags.
- I put as much of the functionality in the base class as possible in order to minimize copy-pasted code. Some duplication remains, but I don't want perfect to be the enemy of better.
- I remove TLM (long live TLM).
- I temporarily remove support for OSS XLM, which should probably have its own tensorizer anyway since it has nothing to do with transformer_sentence_encoder.

Reviewed By: rutyrinott

Differential Revision: D18290264

fbshipit-source-id: eecd35958d44dc2f37dd27099e86451565b2ce3b
1 parent 322fc47 commit 06f528c

12 files changed: +299 -349 lines
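
The refactor described in the summary boils down to a template-method pattern: BERTTensorizerBase owns numberize/tensorize, and each subclass overrides only the _lookup_tokens and _wrap_numberized_text hooks. Before the per-file diffs, here is a minimal, dependency-free sketch of that pattern with a toy vocabulary; the class and method names mirror the diff below, but none of this is the real pytext API.

import itertools
from typing import Dict, List, Tuple


class ToyBERTTensorizerBase:
    """Toy stand-in for BERTTensorizerBase: numberize() is shared, hooks vary."""

    def __init__(self, columns: List[str], vocab: Dict[str, int]) -> None:
        self.columns = columns
        self.vocab = vocab

    def _lookup_tokens(self, text: str) -> List[int]:
        # Default behavior: wrap every sentence with BOS ... EOS.
        ids = [self.vocab.get(tok, self.vocab["<unk>"]) for tok in text.split()]
        return [self.vocab["<s>"]] + ids + [self.vocab["</s>"]]

    def _wrap_numberized_text(self, sentences: List[List[int]]) -> List[List[int]]:
        # Default: no extra wrapping; subclasses may override (eg: BERT).
        return sentences

    def numberize(self, row: Dict[str, str]) -> Tuple[List[int], List[int], int, List[int]]:
        sentences = [self._lookup_tokens(row[column]) for column in self.columns]
        sentences = self._wrap_numberized_text(sentences)
        segment_labels = [i for i, sent in enumerate(sentences) for _ in sent]
        tokens = list(itertools.chain(*sentences))
        return tokens, segment_labels, len(tokens), list(range(len(tokens)))


class ToyBERTTensorizer(ToyBERTTensorizerBase):
    """BERT-style subclass: per-sentence EOS only, single BOS prepended afterwards."""

    def _lookup_tokens(self, text: str) -> List[int]:
        ids = [self.vocab.get(tok, self.vocab["<unk>"]) for tok in text.split()]
        return ids + [self.vocab["</s>"]]

    def _wrap_numberized_text(self, sentences: List[List[int]]) -> List[List[int]]:
        sentences[0] = [self.vocab["<s>"]] + sentences[0]
        return sentences


vocab = {"<s>": 0, "</s>": 1, "<unk>": 2, "hello": 3, "world": 4}
tensorizer = ToyBERTTensorizer(columns=["text1", "text2"], vocab=vocab)
print(tensorizer.numberize({"text1": "hello world", "text2": "world"}))
# ([0, 3, 4, 1, 4, 1], [0, 0, 0, 0, 1, 1], 6, [0, 1, 2, 3, 4, 5])

The printed tuple has the same shape as what the real numberize returns in the diff: tokens, segment_labels, seq_len, positions.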

pytext/data/bert_tensorizer.py

Lines changed: 149 additions & 65 deletions
@@ -2,12 +2,13 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 import itertools
-from typing import Dict, List
+from typing import Any, Dict, List, Tuple
 
+import torch
 from fairseq.data.dictionary import Dictionary
 from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
 from pytext.config.component import ComponentType, create_component
-from pytext.data.tensorizers import TokenTensorizer, lookup_tokens
+from pytext.data.tensorizers import Tensorizer, lookup_tokens
 from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
 from pytext.data.utils import (
     BOS,
@@ -43,35 +44,146 @@ def build_fairseq_vocab(
    )
 
 
-class BERTTensorizer(TokenTensorizer):
+class BERTTensorizerBase(Tensorizer):
     """
-    Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
+    Base Tensorizer class for all BERT style models including XLM,
+    RoBERTa and XLM-R.
     """
 
     __EXPANSIBLE__ = True
 
-    class Config(TokenTensorizer.Config):
-        #: The tokenizer to use to split input text into tokens.
+    class Config(Tensorizer.Config):
+        # BERT style models support multiple text inputs
         columns: List[str] = ["text"]
+        tokenizer: Tokenizer.Config = Tokenizer.Config()
+        vocab_file: str = ""
+        max_seq_len: int = 256
+
+    def __init__(
+        self,
+        columns: List[str] = Config.columns,
+        vocab: Vocabulary = None,
+        tokenizer: Tokenizer = None,
+        max_seq_len: int = Config.max_seq_len,
+    ) -> None:
+        self.columns = columns
+        self.vocab = vocab
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        # Needed to ensure that we're not masking special tokens. By default
+        # we use the BOS token from the vocab. If a class has different
+        # behavior (eg: XLM), it needs to override this.
+        self.bos_token = self.vocab.bos_token
+
+    @property
+    def column_schema(self):
+        return [(column, str) for column in self.columns]
+
+    def _lookup_tokens(self, text: str, seq_len: int = None):
+        """
+        This function knows how to call lookup_tokens with the correct
+        settings for this model. The default behavior is to wrap the
+        numberized text with distinct BOS and EOS tokens. The resulting
+        vector would look something like this:
+        [BOS, token1_id, . . . tokenN_id, EOS]
+
+        The function also takes an optional seq_len parameter which is
+        used to customize truncation in case we have multiple text fields.
+        By default max_seq_len is used. It's up to the numberize function of
+        the class to decide how to use the seq_len param.
+
+        For example:
+        - In the case of sentence pair classification, we might want both
+          pieces of text to have the same length, which is half of the
+          max_seq_len supported by the model.
+        - In the case of QA, we might want to truncate the context by a
+          seq_len which is longer than what we use for the question.
+        """
+        return lookup_tokens(
+            text,
+            tokenizer=self.tokenizer,
+            vocab=self.vocab,
+            bos_token=self.vocab.bos_token,
+            eos_token=self.vocab.eos_token,
+            max_seq_len=seq_len if seq_len else self.max_seq_len,
+        )
+
+    def _wrap_numberized_text(
+        self, numberized_sentences: List[List[str]]
+    ) -> List[List[str]]:
+        """
+        If a class has a non-standard way of generating the final numberized text
+        (eg: BERT) then a class-specific version of the wrap_numberized_text
+        function should be implemented. This allows us to share the numberize
+        function across classes without having to copy-paste code. The default
+        implementation doesn't do anything.
+        """
+        return numberized_sentences
+
+    def numberize(self, row: Dict) -> Tuple[Any, ...]:
+        """
+        This function contains logic for converting tokens into ids based on
+        the specified vocab. It also outputs, for each instance, the vectors
+        needed to run the actual model.
+        """
+        sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
+        sentences = self._wrap_numberized_text(sentences)
+        seq_lens = (len(sentence) for sentence in sentences)
+        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
+        tokens = list(itertools.chain(*sentences))
+        segment_labels = list(itertools.chain(*segment_labels))
+        seq_len = len(tokens)
+        positions = list(range(seq_len))
+        # tokens, segment_label, seq_len
+        return tokens, segment_labels, seq_len, positions
+
+    def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
+        """
+        Convert instance level vectors into batch level tensors.
+        """
+        tokens, segment_labels, seq_lens, positions = zip(*batch)
+        tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
+        pad_mask = (tokens != self.vocab.get_pad_index()).long()
+        segment_labels = pad_and_tensorize(segment_labels)
+        positions = pad_and_tensorize(positions)
+        return tokens, pad_mask, segment_labels, positions
+
+    def initialize(self, vocab_builder=None, from_scratch=True):
+        # vocab for BERT is already set
+        return
+        # we need yield here to make this function a generator
+        yield
+
+    def sort_key(self, row):
+        return row[2]
+
+
+class BERTTensorizer(BERTTensorizerBase):
+    """
+    Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
+    """
+
+    __EXPANSIBLE__ = True
+
+    class Config(BERTTensorizerBase.Config):
         tokenizer: Tokenizer.Config = WordPieceTokenizer.Config()
-        add_bos_token: bool = True
-        add_eos_token: bool = True
-        bos_token: str = "[CLS]"
-        eos_token: str = "[SEP]"
-        pad_token: str = "[PAD]"
-        unk_token: str = "[UNK]"
-        mask_token: str = "[MASK]"
         vocab_file: str = WordPieceTokenizer.Config().wordpiece_vocab_path
 
     @classmethod
     def from_config(cls, config: Config, **kwargs):
+        """
+        from_config parses the config associated with the tensorizer and
+        creates both the tokenizer and the Vocabulary object. The extra arguments
+        passed as kwargs allow us to reuse this function with a variable number
+        of arguments (eg: for classes which derive from this class).
+        """
         tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
         special_token_replacements = {
-            config.unk_token: UNK,
-            config.pad_token: PAD,
-            config.bos_token: BOS,
-            config.eos_token: EOS,
-            config.mask_token: MASK,
+            "[UNK]": UNK,
+            "[PAD]": PAD,
+            "[CLS]": BOS,
+            "[MASK]": MASK,
+            "[SEP]": EOS,
        }
        if isinstance(tokenizer, WordPieceTokenizer):
            vocab = Vocabulary(
@@ -86,64 +198,36 @@ def from_config(cls, config: Config, **kwargs):
            )
        return cls(
            columns=config.columns,
+            vocab=vocab,
            tokenizer=tokenizer,
-            add_bos_token=config.add_bos_token,
-            add_eos_token=config.add_eos_token,
-            use_eos_token_for_bos=config.use_eos_token_for_bos,
            max_seq_len=config.max_seq_len,
-            vocab=vocab,
            **kwargs,
        )
 
-    def __init__(self, columns, **kwargs):
-        super().__init__(text_column=None, **kwargs)
-        self.columns = columns
-        # Manually initialize column_schema since we are sending None to TokenTensorizer
-
-    def initialize(self, vocab_builder=None, from_scratch=True):
-        # vocab for BERT is already set
-        return
-        # we need yield here to make this function a generator
-        yield
-
-    @property
-    def column_schema(self):
-        return [(column, str) for column in self.columns]
+    def __init__(
+        self,
+        columns: List[str] = Config.columns,
+        vocab: Vocabulary = None,
+        tokenizer: Tokenizer = None,
+        max_seq_len: int = Config.max_seq_len,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
+        )
 
-    def _lookup_tokens(self, text):
+    def _lookup_tokens(self, text: str, seq_len: int = None):
        return lookup_tokens(
            text,
            tokenizer=self.tokenizer,
            vocab=self.vocab,
            bos_token=None,
            eos_token=self.vocab.eos_token,
-            max_seq_len=self.max_seq_len,
+            max_seq_len=seq_len if seq_len else self.max_seq_len,
        )
 
-    def numberize(self, row):
-        """Tokenize, look up in vocabulary."""
-        sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
-        if self.add_bos_token:
-            bos_token = (
-                self.vocab.eos_token
-                if self.use_eos_token_for_bos
-                else self.vocab.bos_token
-            )
-            sentences[0] = [self.vocab.idx[bos_token]] + sentences[0]
-        seq_lens = (len(sentence) for sentence in sentences)
-        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
-        tokens = list(itertools.chain(*sentences))
-        segment_labels = list(itertools.chain(*segment_labels))
-        seq_len = len(tokens)
-        # tokens, segment_label, seq_len
-        return tokens, segment_labels, seq_len
-
-    def sort_key(self, row):
-        return row[2]
-
-    def tensorize(self, batch):
-        tokens, segment_labels, seq_lens = zip(*batch)
-        tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
-        pad_mask = (tokens != self.vocab.get_pad_index()).long()
-        segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
-        return tokens, pad_mask, segment_labels
+    def _wrap_numberized_text(
+        self, numberized_sentences: List[List[str]]
+    ) -> List[List[str]]:
+        numberized_sentences[0] = [self.vocab.get_bos_index()] + numberized_sentences[0]
+        return numberized_sentences
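
To make the batching contract concrete: tensorize takes the per-instance (tokens, segment_labels, seq_len, positions) tuples produced by numberize and pads them into batch tensors plus a pad mask. Below is a small self-contained sketch of that step using plain torch; pad_and_tensorize here is a hypothetical stand-in written only to illustrate the padding behavior, not pytext's actual helper.

from typing import List, Tuple

import torch


def pad_and_tensorize(batch, pad_idx: int = 0) -> torch.Tensor:
    """Hypothetical stand-in: right-pad every sequence to the longest one in
    the batch and stack the result into a LongTensor."""
    max_len = max(len(seq) for seq in batch)
    padded = [list(seq) + [pad_idx] * (max_len - len(seq)) for seq in batch]
    return torch.tensor(padded, dtype=torch.long)


def tensorize(batch, pad_idx: int = 0) -> Tuple[torch.Tensor, ...]:
    # batch is an iterable of (tokens, segment_labels, seq_len, positions) tuples.
    tokens, segment_labels, seq_lens, positions = zip(*batch)
    tokens = pad_and_tensorize(tokens, pad_idx)
    pad_mask = (tokens != pad_idx).long()  # 1 for real tokens, 0 for padding
    segment_labels = pad_and_tensorize(segment_labels)
    positions = pad_and_tensorize(positions)
    return tokens, pad_mask, segment_labels, positions


batch = [
    ([0, 3, 4, 1], [0, 0, 0, 0], 4, [0, 1, 2, 3]),
    ([0, 3, 1], [0, 0, 0], 3, [0, 1, 2]),
]
tokens, pad_mask, segment_labels, positions = tensorize(batch, pad_idx=2)
print(tokens.tolist())    # [[0, 3, 4, 1], [0, 3, 1, 2]]
print(pad_mask.tolist())  # [[1, 1, 1, 1], [1, 1, 1, 0]]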

pytext/data/packed_lm_data.py

Lines changed: 5 additions & 9 deletions
@@ -5,7 +5,7 @@
 
 from pytext.common.constants import Stage
 from pytext.data import Batcher, Data
-from pytext.data.bert_tensorizer import BERTTensorizer
+from pytext.data.bert_tensorizer import BERTTensorizerBase
 from pytext.data.data import RowData
 from pytext.data.sources import DataSource
 from pytext.data.tensorizers import Tensorizer, TokenTensorizer
@@ -80,10 +80,8 @@ def _parse_row(self, row):
        numberize. We will simply create this in `_format_output_row`.
        """
        numberized_row = self.tensorizer.numberize(row)
-        if isinstance(self.tensorizer, XLMTensorizer):
-            tokens, seq_len, segment_labels, _ = numberized_row
-        elif isinstance(self.tensorizer, BERTTensorizer):
-            tokens, segment_labels, seq_len = numberized_row
+        if isinstance(self.tensorizer, BERTTensorizerBase):
+            tokens, segment_labels, seq_len, _ = numberized_row
        elif isinstance(self.tensorizer, TokenTensorizer):
            tokens, seq_len, _ = numberized_row
            segment_labels = []
@@ -102,11 +100,9 @@ def _format_output_row(self, tokens, segment_labels, seq_len):
        In case of the XLMTensorizer, we also need to create a new positions list
        which goes from 0 to seq_len.
        """
-        if isinstance(self.tensorizer, XLMTensorizer):
+        if isinstance(self.tensorizer, BERTTensorizerBase):
            positions = [index for index in range(seq_len)]
-            return {self.tensorizer_name: (tokens, seq_len, segment_labels, positions)}
-        elif isinstance(self.tensorizer, BERTTensorizer):
-            return {self.tensorizer_name: (tokens, segment_labels, seq_len)}
+            return {self.tensorizer_name: (tokens, segment_labels, seq_len, positions)}
        elif isinstance(self.tensorizer, TokenTensorizer):
            # dummy token_ranges
            return {self.tensorizer_name: (tokens, seq_len, [(-1, -1)] * seq_len)}

pytext/data/roberta_tensorizer.py

Lines changed: 13 additions & 36 deletions
@@ -4,8 +4,7 @@
 from typing import List
 
 from pytext.config.component import ComponentType, create_component
-from pytext.data.bert_tensorizer import BERTTensorizer, build_fairseq_vocab
-from pytext.data.tensorizers import Tensorizer, lookup_tokens
+from pytext.data.bert_tensorizer import BERTTensorizerBase, build_fairseq_vocab
 from pytext.data.tokenizers import GPT2BPETokenizer, Tokenizer
 from pytext.data.utils import BOS, EOS, PAD, UNK, Vocabulary
 from pytext.torchscript.tensorizer import (
@@ -15,64 +14,42 @@
 from pytext.torchscript.vocab import ScriptVocabulary
 
 
-class RoBERTaTensorizer(BERTTensorizer):
-    class Config(Tensorizer.Config):
-        columns: List[str] = ["text"]
+class RoBERTaTensorizer(BERTTensorizerBase):
+    class Config(BERTTensorizerBase.Config):
        vocab_file: str = (
            "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
        )
        tokenizer: GPT2BPETokenizer.Config = GPT2BPETokenizer.Config()
-        # Make special tokens configurable so we don't need a new
-        # tensorizer if the model is trained with different special token
-        bos_token: str = "<s>"
-        eos_token: str = "</s>"
-        pad_token: str = "<pad>"
-        unk_token: str = "<unk>"
        max_seq_len: int = 256
 
     @classmethod
-    def from_config(cls, config: Config, **kwargs):
+    def from_config(cls, config: Config):
        tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
        vocab = build_fairseq_vocab(
            vocab_file=config.vocab_file,
            special_token_replacements={
-                config.pad_token: PAD,
-                config.bos_token: BOS,
-                config.eos_token: EOS,
-                config.unk_token: UNK,
+                "<pad>": PAD,
+                "<s>": BOS,
+                "</s>": EOS,
+                "<unk>": UNK,
            },
        )
        return cls(
            columns=config.columns,
+            vocab=vocab,
            tokenizer=tokenizer,
            max_seq_len=config.max_seq_len,
-            vocab=vocab,
        )
 
     def __init__(
        self,
-        columns: List[str],
-        tokenizer: Tokenizer = None,
+        columns: List[str] = Config.columns,
        vocab: Vocabulary = None,
-        max_seq_len=256,
+        tokenizer: Tokenizer = None,
+        max_seq_len: int = Config.max_seq_len,
     ) -> None:
        super().__init__(
-            columns=columns,
-            tokenizer=tokenizer,
-            add_bos_token=False,
-            add_eos_token=True,
-            max_seq_len=max_seq_len,
-            vocab=vocab,
-        )
-
-    def _lookup_tokens(self, text: str):
-        return lookup_tokens(
-            text,
-            tokenizer=self.tokenizer,
-            vocab=self.vocab,
-            bos_token=self.vocab.bos_token,
-            eos_token=self.vocab.eos_token,
-            max_seq_len=self.max_seq_len,
+            columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
        )
 
     def torchscriptify(self):
