# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
- from typing import Dict, List
+ from typing import Any, Dict, List, Tuple

+ import torch
from fairseq.data.dictionary import Dictionary
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
- from pytext.data.tensorizers import TokenTensorizer, lookup_tokens
+ from pytext.data.tensorizers import Tensorizer, lookup_tokens
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import (
    BOS,
@@ -43,35 +44,146 @@ def build_fairseq_vocab(
    )


- class BERTTensorizer(TokenTensorizer):
+ class BERTTensorizerBase(Tensorizer):
    """
-     Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
+     Base Tensorizer class for all BERT style models including XLM,
+     RoBERTa and XLM-R.
    """

    __EXPANSIBLE__ = True

-     class Config(TokenTensorizer.Config):
-         #: The tokenizer to use to split input text into tokens.
+     class Config(Tensorizer.Config):
+         # BERT style models support multiple text inputs
        columns: List[str] = ["text"]
+         tokenizer: Tokenizer.Config = Tokenizer.Config()
+         vocab_file: str = ""
+         max_seq_len: int = 256
+
+     def __init__(
+         self,
+         columns: List[str] = Config.columns,
+         vocab: Vocabulary = None,
+         tokenizer: Tokenizer = None,
+         max_seq_len: int = Config.max_seq_len,
+     ) -> None:
+         self.columns = columns
+         self.vocab = vocab
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         # Needed to ensure that we're not masking special tokens. By default
+         # we use the BOS token from the vocab. If a class has different
+         # behavior (eg: XLM), it needs to override this.
+         self.bos_token = self.vocab.bos_token
+
+     @property
+     def column_schema(self):
+         return [(column, str) for column in self.columns]
+
+     def _lookup_tokens(self, text: str, seq_len: int = None):
+         """
+         This function knows how to call lookup_tokens with the correct
+         settings for this model. The default behavior is to wrap the
+         numberized text with distinct BOS and EOS tokens. The resulting
+         vector would look something like this:
+         [BOS, token1_id, ... tokenN_id, EOS]
+
+         The function also takes an optional seq_len parameter which is
+         used to customize truncation in case we have multiple text fields.
+         By default max_seq_len is used. It's up to the numberize function of
+         the class to decide how to use the seq_len param.
+
+         For example:
+         - In the case of sentence pair classification, we might want both
+           pieces of text to have the same length, which is half of the
+           max_seq_len supported by the model.
+         - In the case of QA, we might want to truncate the context by a
+           seq_len which is longer than what we use for the question.
+         """
+         return lookup_tokens(
+             text,
+             tokenizer=self.tokenizer,
+             vocab=self.vocab,
+             bos_token=self.vocab.bos_token,
+             eos_token=self.vocab.eos_token,
+             max_seq_len=seq_len if seq_len else self.max_seq_len,
+         )
+
+     def _wrap_numberized_text(
+         self, numberized_sentences: List[List[str]]
+     ) -> List[List[str]]:
+         """
+         If a class has a non-standard way of generating the final numberized text
+         (eg: BERT) then a class specific version of wrap_numberized_text
+         should be implemented. This allows us to share the numberize
+         function across classes without having to copy-paste code. The default
+         implementation doesn't do anything.
+         """
+         return numberized_sentences
+
+     def numberize(self, row: Dict) -> Tuple[Any, ...]:
+         """
+         This function contains logic for converting tokens into ids based on
+         the specified vocab. It also outputs, for each instance, the vectors
+         needed to run the actual model.
+         """
+         sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
+         sentences = self._wrap_numberized_text(sentences)
+         seq_lens = (len(sentence) for sentence in sentences)
+         segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
+         tokens = list(itertools.chain(*sentences))
+         segment_labels = list(itertools.chain(*segment_labels))
+         seq_len = len(tokens)
+         positions = list(range(seq_len))
+         # tokens, segment_labels, seq_len, positions
+         return tokens, segment_labels, seq_len, positions
+
+     def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
+         """
+         Convert instance level vectors into batch level tensors.
+         """
+         tokens, segment_labels, seq_lens, positions = zip(*batch)
+         tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
+         pad_mask = (tokens != self.vocab.get_pad_index()).long()
+         segment_labels = pad_and_tensorize(segment_labels)
+         positions = pad_and_tensorize(positions)
+         return tokens, pad_mask, segment_labels, positions
+
+     def initialize(self, vocab_builder=None, from_scratch=True):
+         # vocab for BERT is already set
+         return
+         # we need yield here to make this function a generator
+         yield
+
+     def sort_key(self, row):
+         return row[2]
+
+
+ class BERTTensorizer(BERTTensorizerBase):
+     """
+     Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
+     """
+
+     __EXPANSIBLE__ = True
+
+     class Config(BERTTensorizerBase.Config):
        tokenizer: Tokenizer.Config = WordPieceTokenizer.Config()
-         add_bos_token: bool = True
-         add_eos_token: bool = True
-         bos_token: str = "[CLS]"
-         eos_token: str = "[SEP]"
-         pad_token: str = "[PAD]"
-         unk_token: str = "[UNK]"
-         mask_token: str = "[MASK]"
        vocab_file: str = WordPieceTokenizer.Config().wordpiece_vocab_path

    @classmethod
    def from_config(cls, config: Config, **kwargs):
+         """
+         from_config parses the config associated with the tensorizer and
+         creates both the tokenizer and the Vocabulary object. The extra arguments
+         passed as kwargs allow us to reuse this function with a variable number
+         of arguments (eg: for classes which derive from this class).
+         """
        tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
        special_token_replacements = {
-             config.unk_token: UNK,
-             config.pad_token: PAD,
-             config.bos_token: BOS,
-             config.eos_token: EOS,
-             config.mask_token: MASK,
+             "[UNK]": UNK,
+             "[PAD]": PAD,
+             "[CLS]": BOS,
+             "[MASK]": MASK,
+             "[SEP]": EOS,
        }
        if isinstance(tokenizer, WordPieceTokenizer):
            vocab = Vocabulary(
@@ -86,64 +198,36 @@ def from_config(cls, config: Config, **kwargs):
            )
        return cls(
            columns=config.columns,
+             vocab=vocab,
            tokenizer=tokenizer,
-             add_bos_token=config.add_bos_token,
-             add_eos_token=config.add_eos_token,
-             use_eos_token_for_bos=config.use_eos_token_for_bos,
            max_seq_len=config.max_seq_len,
-             vocab=vocab,
            **kwargs,
        )

-     def __init__(self, columns, **kwargs):
-         super().__init__(text_column=None, **kwargs)
-         self.columns = columns
-         # Manually initialize column_schema since we are sending None to TokenTensorizer
-
-     def initialize(self, vocab_builder=None, from_scratch=True):
-         # vocab for BERT is already set
-         return
-         # we need yield here to make this function a generator
-         yield
-
-     @property
-     def column_schema(self):
-         return [(column, str) for column in self.columns]
+     def __init__(
+         self,
+         columns: List[str] = Config.columns,
+         vocab: Vocabulary = None,
+         tokenizer: Tokenizer = None,
+         max_seq_len: int = Config.max_seq_len,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
+         )

-     def _lookup_tokens(self, text):
+     def _lookup_tokens(self, text: str, seq_len: int = None):
        return lookup_tokens(
            text,
            tokenizer=self.tokenizer,
            vocab=self.vocab,
            bos_token=None,
            eos_token=self.vocab.eos_token,
-             max_seq_len=self.max_seq_len,
+             max_seq_len=seq_len if seq_len else self.max_seq_len,
        )

-     def numberize(self, row):
-         """Tokenize, look up in vocabulary."""
-         sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
-         if self.add_bos_token:
-             bos_token = (
-                 self.vocab.eos_token
-                 if self.use_eos_token_for_bos
-                 else self.vocab.bos_token
-             )
-             sentences[0] = [self.vocab.idx[bos_token]] + sentences[0]
-         seq_lens = (len(sentence) for sentence in sentences)
-         segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
-         tokens = list(itertools.chain(*sentences))
-         segment_labels = list(itertools.chain(*segment_labels))
-         seq_len = len(tokens)
-         # tokens, segment_label, seq_len
-         return tokens, segment_labels, seq_len
-
-     def sort_key(self, row):
-         return row[2]
-
-     def tensorize(self, batch):
-         tokens, segment_labels, seq_lens = zip(*batch)
-         tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
-         pad_mask = (tokens != self.vocab.get_pad_index()).long()
-         segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
-         return tokens, pad_mask, segment_labels
+     def _wrap_numberized_text(
+         self, numberized_sentences: List[List[str]]
+     ) -> List[List[str]]:
+         numberized_sentences[0] = [self.vocab.get_bos_index()] + numberized_sentences[0]
+         return numberized_sentences
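
Similarly, a minimal sketch of how the two BERT-specific overrides compose: _lookup_tokens appends only an EOS/[SEP] marker per sentence, and _wrap_numberized_text prepends the BOS/[CLS] index to the first sentence, giving the usual [CLS] sent1 [SEP] sent2 [SEP] layout. The wordpiece ids below are invented for illustration.

# Illustrative only: assume vocab.get_bos_index() == 101 ([CLS]) and [SEP] == 102.
# Per-column output of BERTTensorizer._lookup_tokens (no BOS, EOS appended):
sentences = [[7592, 2088, 102], [2129, 2024, 2017, 102]]

# _wrap_numberized_text prepends the BOS/[CLS] index to the first sentence only:
sentences[0] = [101] + sentences[0]
assert sentences == [[101, 7592, 2088, 102], [2129, 2024, 2017, 102]]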