This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Consolidate BERT, XLM and RobERTa Tensorizers #1119

Closed
214 changes: 149 additions & 65 deletions pytext/data/bert_tensorizer.py
@@ -2,12 +2,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
from typing import Dict, List
from typing import Any, Dict, List, Tuple

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import TokenTensorizer, lookup_tokens
from pytext.data.tensorizers import Tensorizer, lookup_tokens
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import (
BOS,
@@ -43,35 +44,146 @@ def build_fairseq_vocab(
)


class BERTTensorizer(TokenTensorizer):
class BERTTensorizerBase(Tensorizer):
"""
Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
Base Tensorizer class for all BERT-style models, including XLM,
RoBERTa and XLM-R.
"""

__EXPANSIBLE__ = True

class Config(TokenTensorizer.Config):
#: The tokenizer to use to split input text into tokens.
class Config(Tensorizer.Config):
# BERT style models support multiple text inputs
columns: List[str] = ["text"]
tokenizer: Tokenizer.Config = Tokenizer.Config()
vocab_file: str = ""
max_seq_len: int = 256

def __init__(
self,
columns: List[str] = Config.columns,
vocab: Vocabulary = None,
tokenizer: Tokenizer = None,
max_seq_len: int = Config.max_seq_len,
) -> None:
self.columns = columns
self.vocab = vocab
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
# Needed to ensure that we're not masking special tokens. By default
# we use the BOS token from the vocab. If a class has different
# behavior (e.g. XLM), it needs to override this.
self.bos_token = self.vocab.bos_token

@property
def column_schema(self):
return [(column, str) for column in self.columns]

def _lookup_tokens(self, text: str, seq_len: int = None):
"""
This function knows how to call lookup_tokens with the correct
settings for this model. The default behavior is to wrap the
numberized text with distinct BOS and EOS tokens. The resulting
vector would look something like this:
[BOS, token1_id, ..., tokenN_id, EOS]

The function also takes an optional seq_len parameter which is
used to customize truncation in case we have multiple text fields.
By default max_seq_len is used. It's up to the numberize function of
the class to decide how to use the seq_len param.

For example:
- In the case of sentence pair classification, we might want both
pieces of text to have the same length, which is half of the
max_seq_len supported by the model.
- In the case of QA, we might want to truncate the context by a
seq_len which is longer than what we use for the question.
"""
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.vocab.bos_token,
eos_token=self.vocab.eos_token,
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def _wrap_numberized_text(
self, numberized_sentences: List[List[str]]
) -> List[List[str]]:
"""
If a class has a non-standard way of generating the final numberized text
(e.g. BERT), then a class-specific version of the _wrap_numberized_text method
should be implemented. This allows us to share the numberize
function across classes without having to copy-paste code. The default
implementation doesn't do anything.
"""
return numberized_sentences

def numberize(self, row: Dict) -> Tuple[Any, ...]:
"""
This function contains logic for converting tokens into ids based on
the specified vocab. It also outputs, for each instance, the vectors
needed to run the actual model.
"""
sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
sentences = self._wrap_numberized_text(sentences)
seq_lens = (len(sentence) for sentence in sentences)
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(*sentences))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))
# tokens, segment_labels, seq_len, positions
return tokens, segment_labels, seq_len, positions

def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
"""
Convert instance level vectors into batch level tensors.
"""
tokens, segment_labels, seq_lens, positions = zip(*batch)
tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
pad_mask = (tokens != self.vocab.get_pad_index()).long()
segment_labels = pad_and_tensorize(segment_labels)
positions = pad_and_tensorize(positions)
return tokens, pad_mask, segment_labels, positions

def initialize(self, vocab_builder=None, from_scratch=True):
# vocab for BERT is already set
return
# we need yield here to make this function a generator
yield

def sort_key(self, row):
return row[2]


class BERTTensorizer(BERTTensorizerBase):
"""
Tensorizer for BERT tasks. Works for single sentence, sentence pair, triples etc.
"""

__EXPANSIBLE__ = True

class Config(BERTTensorizerBase.Config):
tokenizer: Tokenizer.Config = WordPieceTokenizer.Config()
add_bos_token: bool = True
add_eos_token: bool = True
bos_token: str = "[CLS]"
eos_token: str = "[SEP]"
pad_token: str = "[PAD]"
unk_token: str = "[UNK]"
mask_token: str = "[MASK]"
vocab_file: str = WordPieceTokenizer.Config().wordpiece_vocab_path

@classmethod
def from_config(cls, config: Config, **kwargs):
"""
from_config parses the config associated with the tensorizer and
creates both the tokenizer and the Vocabulary object. The extra arguments
passed as kwargs allow us to reuse this function with a variable number
of arguments (e.g. for classes which derive from this class).
"""
tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
special_token_replacements = {
config.unk_token: UNK,
config.pad_token: PAD,
config.bos_token: BOS,
config.eos_token: EOS,
config.mask_token: MASK,
"[UNK]": UNK,
"[PAD]": PAD,
"[CLS]": BOS,
"[MASK]": MASK,
"[SEP]": EOS,
}
if isinstance(tokenizer, WordPieceTokenizer):
vocab = Vocabulary(
@@ -86,64 +198,36 @@ def from_config(cls, config: Config, **kwargs):
)
return cls(
columns=config.columns,
vocab=vocab,
tokenizer=tokenizer,
add_bos_token=config.add_bos_token,
add_eos_token=config.add_eos_token,
use_eos_token_for_bos=config.use_eos_token_for_bos,
max_seq_len=config.max_seq_len,
vocab=vocab,
**kwargs,
)

def __init__(self, columns, **kwargs):
super().__init__(text_column=None, **kwargs)
self.columns = columns
# Manually initialize column_schema since we are sending None to TokenTensorizer

def initialize(self, vocab_builder=None, from_scratch=True):
# vocab for BERT is already set
return
# we need yield here to make this function a generator
yield

@property
def column_schema(self):
return [(column, str) for column in self.columns]
def __init__(
self,
columns: List[str] = Config.columns,
vocab: Vocabulary = None,
tokenizer: Tokenizer = None,
max_seq_len: int = Config.max_seq_len,
**kwargs,
) -> None:
super().__init__(
columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
)

def _lookup_tokens(self, text):
def _lookup_tokens(self, text: str, seq_len: int = None):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=None,
eos_token=self.vocab.eos_token,
max_seq_len=self.max_seq_len,
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def numberize(self, row):
"""Tokenize, look up in vocabulary."""
sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
if self.add_bos_token:
bos_token = (
self.vocab.eos_token
if self.use_eos_token_for_bos
else self.vocab.bos_token
)
sentences[0] = [self.vocab.idx[bos_token]] + sentences[0]
seq_lens = (len(sentence) for sentence in sentences)
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(*sentences))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
# tokens, segment_label, seq_len
return tokens, segment_labels, seq_len

def sort_key(self, row):
return row[2]

def tensorize(self, batch):
tokens, segment_labels, seq_lens = zip(*batch)
tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
pad_mask = (tokens != self.vocab.get_pad_index()).long()
segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
return tokens, pad_mask, segment_labels
def _wrap_numberized_text(
self, numberized_sentences: List[List[str]]
) -> List[List[str]]:
numberized_sentences[0] = [self.vocab.get_bos_index()] + numberized_sentences[0]
return numberized_sentences
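
For reference, a minimal sketch in plain Python (not PyText code) of what the new BERTTensorizerBase.numberize produces for a hypothetical two-column row. The token ids are made up; the point is the shape of the (tokens, segment_labels, seq_len, positions) tuple that tensorize later pads into batch tensors.

import itertools

# Assume _lookup_tokens already wrapped each column with BOS (0) and EOS (2),
# as the base class does; BERTTensorizer instead prepends a single BOS via
# _wrap_numberized_text.
sentences = [[0, 11, 12, 2], [0, 21, 22, 23, 2]]

seq_lens = (len(sentence) for sentence in sentences)
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(*sentences))               # [0, 11, 12, 2, 0, 21, 22, 23, 2]
segment_labels = list(itertools.chain(*segment_labels))  # [0, 0, 0, 0, 1, 1, 1, 1, 1]
seq_len = len(tokens)                                     # 9
positions = list(range(seq_len))                          # [0, 1, ..., 8]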
14 changes: 5 additions & 9 deletions pytext/data/packed_lm_data.py
@@ -5,7 +5,7 @@

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.bert_tensorizer import BERTTensorizerBase
from pytext.data.data import RowData
from pytext.data.sources import DataSource
from pytext.data.tensorizers import Tensorizer, TokenTensorizer
@@ -80,10 +80,8 @@ def _parse_row(self, row):
numberize. We will simply create this in `_format_output_row`.
"""
numberized_row = self.tensorizer.numberize(row)
if isinstance(self.tensorizer, XLMTensorizer):
tokens, seq_len, segment_labels, _ = numberized_row
elif isinstance(self.tensorizer, BERTTensorizer):
tokens, segment_labels, seq_len = numberized_row
if isinstance(self.tensorizer, BERTTensorizerBase):
tokens, segment_labels, seq_len, _ = numberized_row
elif isinstance(self.tensorizer, TokenTensorizer):
tokens, seq_len, _ = numberized_row
segment_labels = []
@@ -102,11 +100,9 @@ def _format_output_row(self, tokens, segment_labels, seq_len):
In the case of the XLMTensorizer, we also need to create a new positions list
which goes from 0 to seq_len.
"""
if isinstance(self.tensorizer, XLMTensorizer):
if isinstance(self.tensorizer, BERTTensorizerBase):
positions = [index for index in range(seq_len)]
return {self.tensorizer_name: (tokens, seq_len, segment_labels, positions)}
elif isinstance(self.tensorizer, BERTTensorizer):
return {self.tensorizer_name: (tokens, segment_labels, seq_len)}
return {self.tensorizer_name: (tokens, segment_labels, seq_len, positions)}
elif isinstance(self.tensorizer, TokenTensorizer):
# dummy token_ranges
return {self.tensorizer_name: (tokens, seq_len, [(-1, -1)] * seq_len)}
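
For context, a sketch of the consolidated dispatch that packed_lm_data.py relies on after this change: every subclass of BERTTensorizerBase (BERT, XLM, RoBERTa) returns the same (tokens, segment_labels, seq_len, positions) tuple from numberize, so one isinstance check replaces the per-class branches. The helper below is illustrative only, not the actual PackedLMData method; the imports match the diff.

from pytext.data.bert_tensorizer import BERTTensorizerBase
from pytext.data.tensorizers import TokenTensorizer

def parse_numberized(tensorizer, numberized_row):
    if isinstance(tensorizer, BERTTensorizerBase):
        # (tokens, segment_labels, seq_len, positions); positions are rebuilt later
        tokens, segment_labels, seq_len, _ = numberized_row
    elif isinstance(tensorizer, TokenTensorizer):
        # (tokens, seq_len, token_ranges); no segment labels
        tokens, seq_len, _ = numberized_row
        segment_labels = []
    else:
        raise TypeError(f"Unsupported tensorizer: {type(tensorizer)}")
    return tokens, segment_labels, seq_len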
49 changes: 13 additions & 36 deletions pytext/data/roberta_tensorizer.py
@@ -4,8 +4,7 @@
from typing import List

from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import BERTTensorizer, build_fairseq_vocab
from pytext.data.tensorizers import Tensorizer, lookup_tokens
from pytext.data.bert_tensorizer import BERTTensorizerBase, build_fairseq_vocab
from pytext.data.tokenizers import GPT2BPETokenizer, Tokenizer
from pytext.data.utils import BOS, EOS, PAD, UNK, Vocabulary
from pytext.torchscript.tensorizer import (
@@ -15,64 +14,42 @@
from pytext.torchscript.vocab import ScriptVocabulary


class RoBERTaTensorizer(BERTTensorizer):
class Config(Tensorizer.Config):
columns: List[str] = ["text"]
class RoBERTaTensorizer(BERTTensorizerBase):
class Config(BERTTensorizerBase.Config):
vocab_file: str = (
"manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
)
tokenizer: GPT2BPETokenizer.Config = GPT2BPETokenizer.Config()
# Make special tokens configurable so we don't need a new
# tensorizer if the model is trained with different special token
bos_token: str = "<s>"
eos_token: str = "</s>"
pad_token: str = "<pad>"
unk_token: str = "<unk>"
max_seq_len: int = 256

@classmethod
def from_config(cls, config: Config, **kwargs):
def from_config(cls, config: Config):
tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
vocab = build_fairseq_vocab(
vocab_file=config.vocab_file,
special_token_replacements={
config.pad_token: PAD,
config.bos_token: BOS,
config.eos_token: EOS,
config.unk_token: UNK,
"<pad>": PAD,
"<s>": BOS,
"</s>": EOS,
"<unk>": UNK,
},
)
return cls(
columns=config.columns,
vocab=vocab,
tokenizer=tokenizer,
max_seq_len=config.max_seq_len,
vocab=vocab,
)

def __init__(
self,
columns: List[str],
tokenizer: Tokenizer = None,
columns: List[str] = Config.columns,
vocab: Vocabulary = None,
max_seq_len=256,
tokenizer: Tokenizer = None,
max_seq_len: int = Config.max_seq_len,
) -> None:
super().__init__(
columns=columns,
tokenizer=tokenizer,
add_bos_token=False,
add_eos_token=True,
max_seq_len=max_seq_len,
vocab=vocab,
)

def _lookup_tokens(self, text: str):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.vocab.bos_token,
eos_token=self.vocab.eos_token,
max_seq_len=self.max_seq_len,
columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
)

def torchscriptify(self):
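
A hypothetical usage sketch of the consolidated RoBERTaTensorizer: build it from its config and numberize a single row. The Config fields (vocab_file, max_seq_len) come from the diff above; the vocab path is a placeholder, and the example assumes the GPT-2 BPE files referenced by the default tokenizer config are available locally and that the Config accepts keyword arguments as PyText configs typically do.

from pytext.data.roberta_tensorizer import RoBERTaTensorizer

# Placeholder path: point this at a real fairseq-style dict.txt before running.
config = RoBERTaTensorizer.Config(
    vocab_file="/path/to/gpt2/dict.txt",
    max_seq_len=256,
)
tensorizer = RoBERTaTensorizer.from_config(config)

# numberize() is inherited unchanged from BERTTensorizerBase and returns
# (tokens, segment_labels, seq_len, positions); tensorize() then pads a batch
# into (tokens, pad_mask, segment_labels, positions) tensors.
numberized = tensorizer.numberize({"text": "hello world"})
tokens, pad_mask, segment_labels, positions = tensorizer.tensorize([numberized])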