This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 1e4dd71

Haoran Li authored and facebook-github-bot committed
fix MultipleData by making tensorizers able to initialize from multiple data sources (#972)
Summary: Pull Request resolved: #972

For the newly added Data that can read from multiple data sources, there were issues when initializing tensorizers: they would only be initialized from the last data source. This diff makes tensorizers with a vocab able to initialize from multiple data sources by introducing a new `from_scratch` parameter in tensorizer initialize. When `from_scratch` is set to False for Data, tensorizers accumulate vocab across data sources. Each tensorizer is modified according to its implementation; the basic change is to not create a new vocab builder when `from_scratch` is False.

MultipleData also did not work with run_per_key_testing: it could not update test_path and would always run on one test data source multiple times. The fix adds special handling for multiple data so it can accept a given data source.

Reviewed By: rutyrinott

Differential Revision: D17301822

fbshipit-source-id: 68a41659f8636dae05d36c3f9b7296aae9ba6920
1 parent 5ad6a26 commit 1e4dd71
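
For orientation, here is a minimal, self-contained sketch of the accumulation pattern this commit introduces. The classes are toy stand-ins, not PyText's: the point is that with `from_scratch=False` the tensorizer keeps its existing vocab builder, so a second data source extends the vocab instead of replacing it.

```python
# Toy illustration of the `from_scratch` contract -- not PyText code.
from collections import Counter


class ToyVocabBuilder:
    def __init__(self):
        self.counts = Counter()

    def add_all(self, tokens):
        self.counts.update(tokens)

    def make_vocab(self):
        return sorted(self.counts)


class ToyTensorizer:
    def __init__(self):
        self.vocab = None
        self.vocab_builder = None

    def initialize(self, from_scratch=True):
        # Only build a fresh builder when starting from scratch; otherwise
        # keep accumulating into the one left over from the previous source.
        if not self.vocab_builder or from_scratch:
            self.vocab_builder = ToyVocabBuilder()
        try:
            while True:
                row = yield  # rows are streamed in one at a time
                self.vocab_builder.add_all(row.split())
        except GeneratorExit:
            self.vocab = self.vocab_builder.make_vocab()


tensorizer = ToyTensorizer()
for i, source in enumerate([["hello world"], ["goodbye world"]]):
    init = tensorizer.initialize(from_scratch=(i == 0))
    init.send(None)  # prime the generator
    for row in source:
        init.send(row)
    init.close()  # GeneratorExit -> make_vocab()
print(tensorizer.vocab)  # ['goodbye', 'hello', 'world']
```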

File tree

6 files changed (+109, -73 lines)

pytext/data/bert_tensorizer.py

Lines changed: 6 additions & 0 deletions
@@ -67,6 +67,12 @@ def __init__(self, columns, **kwargs):
         self.columns = columns
         # Manually initialize column_schema since we are sending None to TokenTensorizer

+    def initialize(self, vocab_builder=None, from_scratch=True):
+        # vocab for BERT is already set
+        return
+        # we need yield here to make this function a generator
+        yield
+
     @property
     def column_schema(self):
         return [(column, str) for column in self.columns]
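
The added body is a standard Python idiom rather than anything PyText-specific: the unreachable `yield` after `return` turns the method into a generator function, and the early `return` makes it an empty generator, so the caller's priming `send(None)` raises StopIteration immediately and the tensorizer is skipped. A standalone demonstration:

```python
def noop_initialize():
    # Nothing to consume: the vocab is already set.
    return
    # Unreachable, but its presence makes this a generator function.
    yield


gen = noop_initialize()
try:
    gen.send(None)  # priming an already-finished generator
except StopIteration:
    print("skipped")  # the driver treats this as "no initialization needed"
```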

pytext/data/data.py

Lines changed: 4 additions & 1 deletion
@@ -264,6 +264,7 @@ def __init__(
         sort_key: Optional[str] = None,
         in_memory: Optional[bool] = False,
         init_tensorizers: Optional[bool] = True,
+        init_tensorizers_from_scratch: Optional[bool] = True,
     ):
         """This function should also initialize the passed in tensorizers with
         metadata they need for model construction."""
@@ -280,7 +281,9 @@
             else data_source.train
         )
         if init_tensorizers:
-            initialize_tensorizers(self.tensorizers, full_train_data)
+            initialize_tensorizers(
+                self.tensorizers, full_train_data, init_tensorizers_from_scratch
+            )
         else:
             print(
                 "Skipped initializing tensorizers since they are loaded from a "

pytext/data/disjoint_multitask_data.py

Lines changed: 14 additions & 9 deletions
@@ -59,11 +59,11 @@ def __init__(
         test_key: str = None,
         task_key: str = BatchContext.TASK_NAME,
     ) -> None:
-        test_key = test_key or list(data_dict)[0]
+        self.test_key = test_key or list(data_dict)[0]
         # currently the way training is set up is that, the data object needs
         # to specify a data_source which is used at test time. For multitask
         # this is set to the data_source associated with the test_key
-        self.data_source = data_dict[test_key].data_source
+        self.data_source = data_dict[self.test_key].data_source
         super().__init__(self.data_source, {})
         self.data_dict = data_dict
         self.samplers = samplers
@@ -74,11 +74,16 @@ def batches(self, stage: Stage, data_source=None):
         """Yield batches from each task, sampled according to a given sampler.
         This batcher additionally exposes a task name in the batch to allow the model
         to filter examples to the appropriate tasks."""
-        all_batches = {
-            name: task.batches(stage) for name, task in self.data_dict.items()
-        }
-        sampled_batches = self.samplers[stage].batchify(all_batches)
+        if data_source is not None:
+            # means being called in test workflow
+            for batch in self.data_dict[self.test_key].batches(stage, data_source):
+                yield batch
+        else:
+            all_batches = {
+                name: task.batches(stage) for name, task in self.data_dict.items()
+            }
+            sampled_batches = self.samplers[stage].batchify(all_batches)

-        for name, (raw_batch, batch) in sampled_batches:
-            batch[self.task_key] = name
-            yield BatchData(raw_batch, batch)
+            for name, (raw_batch, batch) in sampled_batches:
+                batch[self.task_key] = name
+                yield BatchData(raw_batch, batch)
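
The new test-time branch simply delegates to the test task's own `batches` generator with the caller-supplied source; everything else keeps the sampled multitask path. A toy sketch of the dispatch (invented task objects, and a plain loop standing in for the samplers):

```python
class ToyTask:
    def __init__(self, rows):
        self.rows = rows

    def batches(self, stage, data_source=None):
        source = data_source if data_source is not None else self.rows
        yield from source


class ToyMultitaskData:
    def __init__(self, data_dict, test_key=None):
        self.data_dict = data_dict
        self.test_key = test_key or list(data_dict)[0]

    def batches(self, stage, data_source=None):
        if data_source is not None:
            # test workflow: route the given source to the test task only
            yield from self.data_dict[self.test_key].batches(stage, data_source)
        else:
            # training: interleave all tasks (real code samples via batchify)
            for name, task in self.data_dict.items():
                for batch in task.batches(stage):
                    yield name, batch


data = ToyMultitaskData({"a": ToyTask([1, 2]), "b": ToyTask([3])}, test_key="b")
print(list(data.batches("test", data_source=[9])))  # [9]
print(list(data.batches("train")))  # [('a', 1), ('a', 2), ('b', 3)]
```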

pytext/data/squad_tensorizer.py

Lines changed: 16 additions & 14 deletions
@@ -96,27 +96,29 @@ def __init__(
         self.answers_column = answers_column
         self.answer_starts_column = answer_starts_column

-    def initialize(self, vocab_builder=None):
+    def initialize(self, vocab_builder=None, from_scratch=True):
         """Build vocabulary based on training corpus."""
-        if not self.vocab:
+        if isinstance(self.tokenizer, WordPieceTokenizer):
+            return
+        if not self.vocab_builder or from_scratch:
             self.vocab_builder = vocab_builder or VocabBuilder()
             self.vocab_builder.pad_index = 0
             self.vocab_builder.unk_index = 1
-        ques_initializer = self.ques_tensorizer.initialize(self.vocab_builder)
-        doc_initializer = self.doc_tensorizer.initialize(self.vocab_builder)
-        ques_initializer.send(None)
-        doc_initializer.send(None)
+        ques_initializer = self.ques_tensorizer.initialize(
+            self.vocab_builder, from_scratch
+        )
+        doc_initializer = self.doc_tensorizer.initialize(
+            self.vocab_builder, from_scratch
+        )
+        ques_initializer.send(None)
+        doc_initializer.send(None)
         try:
             while True:
-                if self.vocab:
-                    yield
-                else:
-                    row = yield
-                    ques_initializer.send(row)
-                    doc_initializer.send(row)
+                row = yield
+                ques_initializer.send(row)
+                doc_initializer.send(row)
         except GeneratorExit:
-            if not self.vocab:
-                self.vocab = self.vocab_builder.make_vocab()
+            self.vocab = self.vocab_builder.make_vocab()

     def _lookup_tokens(self, text, source_is_doc=True):
         # This is useful in SquadMetricReporter._unnumberize()
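
The restructured initializer highlights a coroutine composition trick: the parent generator primes two child generators and fans every received row out to both with `send`. A self-contained sketch of that pattern (toy sinks in place of VocabBuilder):

```python
def child_initializer(sink):
    # Accumulates whitespace tokens until closed -- toy builder stand-in.
    try:
        while True:
            row = yield
            sink.extend(row.split())
    except GeneratorExit:
        pass


def parent_initializer(ques_sink, doc_sink):
    ques = child_initializer(ques_sink)
    doc = child_initializer(doc_sink)
    ques.send(None)  # prime the children to their first `yield`
    doc.send(None)
    try:
        while True:
            row = yield
            ques.send(row)  # fan the same row out to both children
            doc.send(row)
    except GeneratorExit:
        ques.close()
        doc.close()


ques_tokens, doc_tokens = [], []
init = parent_initializer(ques_tokens, doc_tokens)
init.send(None)
init.send("what is pytext")
init.close()
print(ques_tokens)  # ['what', 'is', 'pytext']
```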

pytext/data/tensorizers.py

Lines changed: 57 additions & 40 deletions
@@ -133,7 +133,7 @@ def tensorize(self, batch):
         """Tensorizer knows how to pad and tensorize a batch of it's own output."""
         return batch

-    def initialize(self):
+    def initialize(self, from_scratch=True):
         """
         The initialize function is carefully designed to allow us to read through the
         training dataset only once, and not store it in memory. As such, it can't itself
@@ -264,9 +264,9 @@ def _lookup_tokens(self, text=None, pre_tokenized=None):
     def _reverse_lookup(self, token_ids):
         return [self.vocab[id] for id in token_ids]

-    def initialize(self, vocab_builder=None):
+    def initialize(self, vocab_builder=None, from_scratch=True):
         """Build vocabulary based on training corpus."""
-        if self.vocab:
+        if self.vocab and from_scratch:
             if self.vocab_config.build_from_data or self.vocab_config.vocab_files:
                 print(
                     f"`{self.text_column}` column: vocab already provided, skipping "
@@ -279,10 +279,12 @@ def initialize(self, vocab_builder=None):
                 f"To create token tensorizer for '{self.text_column}', either "
                 f"`build_from_data` or `vocab_files` must be set."
             )
-
-        self.vocab_builder = vocab_builder or VocabBuilder()
-        self.vocab_builder.use_bos = self.add_bos_token
-        self.vocab_builder.use_eos = self.add_eos_token
+        if not self.vocab_builder:
+            # else means not initialize from scratch, self.vocab_builder
+            # would be set already
+            self.vocab_builder = vocab_builder or VocabBuilder()
+            self.vocab_builder.use_bos = self.add_bos_token
+            self.vocab_builder.use_eos = self.add_eos_token
         if not self.vocab_config.build_from_data:
             self._add_vocab_from_files()
             self.vocab = self.vocab_builder.make_vocab()
@@ -561,11 +563,11 @@ def __init__(
     def column_schema(self):
         return [(self.label_column, str)]

-    def initialize(self):
+    def initialize(self, from_scratch=True):
         """
         Look through the dataset for all labels and create a vocab map for them.
         """
-        if self.vocab:
+        if self.vocab and from_scratch:
             return
         try:
             while True:
@@ -652,11 +654,11 @@ def _get_row_value_as_str(self, row) -> str:
         row_value = str(row_value.item())
         return row_value

-    def initialize(self):
+    def initialize(self, from_scratch=True):
         """
         Look through the dataset for all uids and create a vocab map for them.
         """
-        if self.vocab:
+        if self.vocab and from_scratch:
             return
         try:
             while True:
@@ -881,24 +883,27 @@ def __init__(
         self.allow_unknown = allow_unknown
         self.tokenizer = tokenizer or Tokenizer()
         self.pad_idx = Padding.DEFAULT_LABEL_PAD_IDX
+        self.vocab_builder = VocabBuilder()
+        self.vocab_builder.add(NO_LABEL)
+        self.vocab_builder.use_pad = False
+        self.vocab_builder.use_unk = self.allow_unknown
+        self.vocab = None

     @property
     def column_schema(self):
         return [(self.text_column, str), (self.slot_column, List[Slot])]

-    def initialize(self):
+    def initialize(self, from_scratch=True):
         """Look through the dataset for all labels and create a vocab map for them."""
-        builder = VocabBuilder()
-        builder.add(NO_LABEL)
-        builder.use_pad = False
-        builder.use_unk = self.allow_unknown
+        if self.vocab and from_scratch:
+            return
         try:
             while True:
                 row = yield
                 slots = row[self.slot_column]
-                builder.add_all(s.label for s in slots)
+                self.vocab_builder.add_all(s.label for s in slots)
         except GeneratorExit:
-            self.vocab = builder.make_vocab()
+            self.vocab = self.vocab_builder.make_vocab()

     def numberize(self, row):
         """
@@ -993,23 +998,26 @@ def __init__(
         self.text_column = text_column
         self.dict_column = dict_column
         self.tokenizer = tokenizer or Tokenizer()
+        self.vocab_builder = VocabBuilder()
+        self.vocab = None

     @property
     def column_schema(self):
         return [(self.text_column, str), (self.dict_column, Gazetteer)]

-    def initialize(self):
+    def initialize(self, from_scratch=True):
         """
         Look through the dataset for all dict features to create vocab.
         """
-        builder = VocabBuilder()
+        if self.vocab and from_scratch:
+            return
         try:
             while True:
                 row = yield
                 for token_dict in row[self.dict_column]:
-                    builder.add_all(token_dict["features"])
+                    self.vocab_builder.add_all(token_dict["features"])
         except GeneratorExit:
-            self.vocab = builder.make_vocab()
+            self.vocab = self.vocab_builder.make_vocab()

     def numberize(self, row):
         """
@@ -1169,6 +1177,7 @@ def __init__(
         self.column = column
         self.tokenizer = tokenizer or Tokenizer()
         self.vocab = vocab
+        self.vocab_builder = None
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
         self.use_eos_token_for_bos = use_eos_token_for_bos
@@ -1181,24 +1190,25 @@
     def column_schema(self):
         return [(self.column, List[str])]

-    def initialize(self, vocab_builder=None):
+    def initialize(self, vocab_builder=None, from_scratch=True):
         """Build vocabulary based on training corpus."""
-        if self.vocab:
+        if self.vocab and from_scratch:
             return
-        vocab_builder = vocab_builder or VocabBuilder()
-        vocab_builder.use_bos = self.add_bos_token
-        vocab_builder.use_eos = self.add_eos_token
-        vocab_builder.use_bol = self.add_bol_token
-        vocab_builder.use_eol = self.add_eol_token
+        if not self.vocab_builder:
+            self.vocab_builder = vocab_builder or VocabBuilder()
+            self.vocab_builder.use_bos = self.add_bos_token
+            self.vocab_builder.use_eos = self.add_eos_token
+            self.vocab_builder.use_bol = self.add_bol_token
+            self.vocab_builder.use_eol = self.add_eol_token

         try:
             while True:
                 row = yield
                 for raw_text in row[self.column]:
                     tokenized = self.tokenizer.tokenize(raw_text)
-                    vocab_builder.add_all([t.value for t in tokenized])
+                    self.vocab_builder.add_all([t.value for t in tokenized])
         except GeneratorExit:
-            self.vocab = vocab_builder.make_vocab()
+            self.vocab = self.vocab_builder.make_vocab()

     _lookup_tokens = TokenTensorizer._lookup_tokens
     _tokenize = TokenTensorizer._tokenize
@@ -1274,27 +1284,29 @@ def from_config(cls, config: Config):
     def __init__(self, column: str = Config.column, vocab=None):
         self.column = column
         self.vocab = vocab
+        self.vocab_builder = None

     @property
     def column_schema(self):
         return [(self.column, List[str])]

-    def initialize(self, vocab_builder=None):
+    def initialize(self, vocab_builder=None, from_scratch=True):
         """Build vocabulary based on training corpus."""
-        if self.vocab:
+        if self.vocab and from_scratch:
             return
-        vocab_builder = vocab_builder or VocabBuilder()
-        vocab_builder.use_unk = False
-        vocab_builder.use_pad = False
+        if not self.vocab_builder:
+            self.vocab_builder = vocab_builder or VocabBuilder()
+            self.vocab_builder.use_unk = False
+            self.vocab_builder.use_pad = False

         try:
             while True:
                 row = yield
                 annotation = Annotation(row[self.column])
                 actions = annotation.tree.to_actions()
-                vocab_builder.add_all(actions)
+                self.vocab_builder.add_all(actions)
         except GeneratorExit:
-            self.vocab = vocab_builder.make_vocab()
+            self.vocab = self.vocab_builder.make_vocab()
         self.shift_idx = self.vocab.idx[SHIFT]
         self.reduce_idx = self.vocab.idx[REDUCE]

@@ -1378,11 +1390,16 @@ def tensorize(self, batch):
         return cuda.tensor(batch, torch.float)


-def initialize_tensorizers(tensorizers, data_source):
+def initialize_tensorizers(tensorizers, data_source, from_scratch=True):
     """A utility function to stream a data source to the initialize functions
     of a dict of tensorizers."""
     initializers = []
-    for init in [tensorizer.initialize() for tensorizer in tensorizers.values()]:
+    for init in [
+        tensorizer.initialize(from_scratch=from_scratch)
+        if hasattr(tensorizer, "vocab")
+        else tensorizer.initialize()
+        for tensorizer in tensorizers.values()
+    ]:
         try:
             init.send(None)  # kick
             initializers.append(init)
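
Two details of the new driver are worth noting. The `hasattr(tensorizer, "vocab")` guard keeps old-style tensorizers (whose `initialize` takes no `from_scratch`) working, and the `send(None)` kick doubles as a filter, because initializers that return immediately raise StopIteration and never join the stream. A runnable sketch of the whole loop with toy tensorizers (not the PyText classes):

```python
class VocabTensorizer:
    def __init__(self):
        self.vocab = None
        self.tokens = []

    def initialize(self, from_scratch=True):
        if self.vocab and from_scratch:
            return  # finishes immediately -> StopIteration on the kick
        if from_scratch:
            self.tokens = []
        try:
            while True:
                row = yield
                self.tokens.append(row)
        except GeneratorExit:
            self.vocab = sorted(set(self.tokens))


class StatelessTensorizer:
    def initialize(self):  # old signature: no vocab, no from_scratch
        return
        yield  # unreachable; makes this a generator function


def drive(tensorizers, data_source, from_scratch=True):
    # Mirrors the shape of initialize_tensorizers above.
    initializers = []
    for init in [
        t.initialize(from_scratch=from_scratch)
        if hasattr(t, "vocab")
        else t.initialize()
        for t in tensorizers.values()
    ]:
        try:
            init.send(None)  # kick
            initializers.append(init)
        except StopIteration:
            pass
    for row in data_source:
        for init in initializers:
            init.send(row)
    for init in initializers:
        init.close()


ts = {"text": VocabTensorizer(), "raw": StatelessTensorizer()}
drive(ts, ["b", "a"])                 # first source, from scratch
drive(ts, ["c"], from_scratch=False)  # second source accumulates
print(ts["text"].vocab)  # ['a', 'b', 'c']
```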

pytext/workflow.py

Lines changed: 12 additions & 9 deletions
@@ -227,10 +227,7 @@ def test_model_from_snapshot_path(

     if isinstance(task, (NewTask, NewDisjointMultitask)):
         data_source = _get_data_source(
-            test_path,
-            getattr(train_config.task.data, "source", None),
-            field_names,
-            task,
+            test_path, train_config.task.data, field_names, task
         )
         test_results = task.test(data_source)
     else:
@@ -240,7 +237,16 @@
     return test_results, test_out_path, metric_channels


-def _get_data_source(test_path, source_config, field_names, task):
+def _get_data_source(test_path, data_config, field_names, task):
+    if hasattr(data_config, "data_dict_config"):
+        # it's multiple data
+        if data_config.test_key:
+            source_config = data_config.data_dict_config[data_config.test_key].source
+        else:
+            source_config = next(iter(data_config.data_dict_config.values())).source
+    else:
+        source_config = getattr(data_config, "source", None)
+
     if isinstance(task, NewDisjointMultitask):
         # Cannot easily specify a single data source for multitask
         assert not test_path
@@ -277,10 +283,7 @@ def get_logits(
     if isinstance(task, NewTask):
         task.model.eval()
         data_source = _get_data_source(
-            test_path,
-            getattr(train_config.task.data, "source", None),
-            field_names,
-            task,
+            test_path, train_config.task.data, field_names, task
         )
         task.data.batcher = Batcher()
         task.data.sort_key = None
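
The helper now receives the whole data config and resolves the source itself: MultipleData-style configs expose a `data_dict_config` mapping plus an optional `test_key`, while everything else falls back to a plain `source` attribute. A toy model of that resolution (config classes invented for illustration):

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ToyDataConfig:  # plain, single-source config
    source: str = "single.tsv"


@dataclass
class ToyMultipleDataConfig:  # MultipleData-style config
    data_dict_config: Dict[str, ToyDataConfig] = field(default_factory=dict)
    test_key: Optional[str] = None


def resolve_source(data_config):
    if hasattr(data_config, "data_dict_config"):
        # multiple data: honor test_key, else fall back to the first entry
        if data_config.test_key:
            return data_config.data_dict_config[data_config.test_key].source
        return next(iter(data_config.data_dict_config.values())).source
    return getattr(data_config, "source", None)


multi = ToyMultipleDataConfig(
    {"a": ToyDataConfig("a.tsv"), "b": ToyDataConfig("b.tsv")}, test_key="b"
)
print(resolve_source(multi))            # b.tsv
print(resolve_source(ToyDataConfig()))  # single.tsv
```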
