This repository was archived by the owner on Nov 22, 2022. It is now read-only.

open source Roberta #1032

Closed
wants to merge 2 commits
52 changes: 48 additions & 4 deletions pytext/data/bert_tensorizer.py
@@ -6,8 +6,8 @@

from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import TokenTensorizer, lookup_tokens
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.tensorizers import Tensorizer, TokenTensorizer, lookup_tokens
from pytext.data.tokenizers import Gpt2Tokenizer, Tokenizer, WordPieceTokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, Vocabulary, pad_and_tensorize


@@ -83,15 +83,19 @@ def _lookup_tokens(self, text):
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=None,
eos_token=EOS,
eos_token=self.vocab.eos_token,
max_seq_len=self.max_seq_len,
)

def numberize(self, row):
"""Tokenize, look up in vocabulary."""
sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
if self.add_bos_token:
bos_token = EOS if self.use_eos_token_for_bos else BOS
bos_token = (
self.vocab.eos_token
if self.use_eos_token_for_bos
else self.vocab.bos_token
)
sentences[0] = [self.vocab.idx[bos_token]] + sentences[0]
seq_lens = (len(sentence) for sentence in sentences)
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
@@ -110,3 +114,43 @@ def tensorize(self, batch):
pad_mask = (tokens != self.vocab.get_pad_index()).long()
segment_labels = pad_and_tensorize(segment_labels, self.vocab.get_pad_index())
return tokens, pad_mask, segment_labels


class RoBERTaTensorizer(BERTTensorizer):
class Config(Tensorizer.Config):
columns: List[str] = ["text"]
tokenizer: Gpt2Tokenizer.Config = Gpt2Tokenizer.Config()
max_seq_len: int = 256

@classmethod
def from_config(cls, config: Config, **kwargs):
tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
vocab = tokenizer.vocab
return cls(
columns=config.columns,
tokenizer=tokenizer,
max_seq_len=config.max_seq_len,
vocab=vocab,
)

def __init__(self, columns, tokenizer=None, vocab=None, max_seq_len=256):
super().__init__(
columns=columns,
tokenizer=tokenizer,
add_bos_token=False,
add_eos_token=True,
max_seq_len=max_seq_len,
vocab=vocab,
)
self.bpe = self.tokenizer.bpe
self.bos = self.tokenizer.bos
self.eos = self.tokenizer.eos

def _lookup_tokens(self, text):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.bos,
eos_token=self.eos,
max_seq_len=self.max_seq_len,
)
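For context, a rough sketch of wiring the new tensorizer up by hand; the local BPE/dictionary paths are stand-ins for the internal manifold:// defaults, and the sample row is made up:

```python
from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import RoBERTaTensorizer
from pytext.data.tokenizers import Gpt2Tokenizer

# Assumption: local copies of the GPT-2 BPE artifacts instead of the
# manifold:// defaults baked into Gpt2Tokenizer.Config.
tokenizer = create_component(
    ComponentType.TOKENIZER,
    Gpt2Tokenizer.Config(
        token_dictionary_path="gpt2/dict.txt",
        bpe_encoder_path="gpt2/encoder.json",
        bpe_vocab_path="gpt2/vocab.bpe",
    ),
)

# The tensorizer reuses the tokenizer's fairseq dictionary as its vocabulary;
# add_bos_token is False because _lookup_tokens already wraps each sentence
# in the dictionary's <s> ... </s> tokens.
tensorizer = RoBERTaTensorizer(
    columns=["text"], tokenizer=tokenizer, vocab=tokenizer.vocab, max_seq_len=256
)

# Per the BERTTensorizer contract, numberize() yields token ids, segment
# labels, and sequence length for one row; tensorize() later pads a batch.
numberized = tensorizer.numberize({"text": "open sourcing RoBERTa"})
```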
51 changes: 48 additions & 3 deletions pytext/data/squad_for_bert_tensorizer.py
@@ -4,8 +4,9 @@
import itertools
from typing import List

from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.utils import BOS, pad_and_tensorize
from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import BERTTensorizer, RoBERTaTensorizer
from pytext.data.utils import pad_and_tensorize


class SquadForBERTTensorizer(BERTTensorizer):
@@ -44,7 +45,8 @@ def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
question_tokens, _, _ = self._lookup_tokens(row[question_column])
question_tokens = [self.vocab.idx[BOS]] + question_tokens
if self.add_bos_token:
question_tokens = [self.vocab.get_bos_index()] + question_tokens
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
@@ -84,3 +86,46 @@ def tensorize(self, batch):
answer_start_idx = pad_and_tensorize(answer_start_idx, self.SPAN_PAD_IDX)
answer_end_idx = pad_and_tensorize(answer_end_idx, self.SPAN_PAD_IDX)
return tokens, pad_mask, segment_labels, answer_start_idx, answer_end_idx


class SquadForRoBERTaTensorizer(SquadForBERTTensorizer, RoBERTaTensorizer):
"""Produces RoBERTa inputs and answer spans for Squad."""

class Config(RoBERTaTensorizer.Config):
columns: List[str] = ["question", "doc"]
# for labels
answers_column: str = "answers"
answer_starts_column: str = "answer_starts"
max_seq_len: int = 256

@classmethod
def from_config(cls, config: Config):
tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
vocab = tokenizer.vocab
return cls(
columns=config.columns,
tokenizer=tokenizer,
vocab=vocab,
answers_column=config.answers_column,
answer_starts_column=config.answer_starts_column,
max_seq_len=config.max_seq_len,
)

def __init__(
self,
columns=Config.columns,
tokenizer=None,
vocab=None,
answers_column: str = Config.answers_column,
answer_starts_column: str = Config.answer_starts_column,
max_seq_len: int = Config.max_seq_len,
):
RoBERTaTensorizer.__init__(
self, columns, tokenizer=tokenizer, vocab=vocab, max_seq_len=max_seq_len
)
self.answers_column = answers_column
self.answer_starts_column = answer_starts_column
self.add_bos_token = False

def _lookup_tokens(self, text):
return RoBERTaTensorizer._lookup_tokens(self, text)
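Since RoBERTaTensorizer._lookup_tokens already brackets both the question and the document with the GPT-2 dictionary's <s>/</s>, the constructor pins add_bos_token to False so numberize() doesn't prepend a second BOS to the question (that is what the new `if self.add_bos_token:` guard above is for). A hedged config sketch, again with local paths standing in for the manifold:// defaults:

```python
from pytext.data.squad_for_bert_tensorizer import SquadForRoBERTaTensorizer
from pytext.data.tokenizers import Gpt2Tokenizer

squad_input = SquadForRoBERTaTensorizer.Config(
    columns=["question", "doc"],
    answers_column="answers",
    answer_starts_column="answer_starts",
    max_seq_len=256,
    # Assumption: local artifact paths instead of the manifold:// defaults.
    tokenizer=Gpt2Tokenizer.Config(
        token_dictionary_path="gpt2/dict.txt",
        bpe_encoder_path="gpt2/encoder.json",
        bpe_vocab_path="gpt2/vocab.bpe",
    ),
)
tensorizer = SquadForRoBERTaTensorizer.from_config(squad_input)
```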
16 changes: 14 additions & 2 deletions pytext/data/tokenizers/__init__.py
@@ -1,7 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from .tokenizer import DoNothingTokenizer, Token, Tokenizer, WordPieceTokenizer
from .tokenizer import (
DoNothingTokenizer,
Gpt2Tokenizer,
Token,
Tokenizer,
WordPieceTokenizer,
)


__all__ = ["Token", "Tokenizer", "DoNothingTokenizer", "WordPieceTokenizer"]
__all__ = [
"Gpt2Tokenizer",
"Token",
"Tokenizer",
"DoNothingTokenizer",
"WordPieceTokenizer",
]
72 changes: 71 additions & 1 deletion pytext/data/tokenizers/tokenizer.py
@@ -1,11 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy
import re
from typing import List, NamedTuple

from fairseq.data.dictionary import Dictionary
from fairseq.data.encoders.gpt2_bpe import get_encoder as create_gpt2_bpe
from fairseq.data.encoders.gpt2_bpe_utils import Encoder as GPT2BPEEncoder
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType, create_component
from pytext.data.utils import Vocabulary
from pytorch_pretrained_bert.tokenization import (
BasicTokenizer,
WordpieceTokenizer,
@@ -145,3 +149,69 @@ def tokenize(self, input_str: str) -> List[Token]:
tokens.append(Token(sub_token, start, end))
start = end
return [token for token in tokens if token.value]


class PickleableGPT2BPEEncoder(GPT2BPEEncoder):
"""Fairseq's encoder stores the regex module as a local reference on its encoders,
which means they can't be saved via pickle.dumps or torch.save. This modified
their save/load logic doesn't store the module, and restores the reference
after re-inflating."""

def __getstate__(self):
# Copy the instance dict so popping "re" doesn't mutate the live object.
state = dict(vars(self))
state.pop("re")
return state

def __setstate__(self, state):
vars(self).update(state)
import regex

self.re = regex


class Gpt2Tokenizer(Tokenizer):
"""Tokenizer for gpt-2 and RoBERTa."""

class Config(ConfigBase):
token_dictionary_path: str = (
"manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
)
bpe_encoder_path: str = (
"manifold://pytext_training/tree/static/vocabs/bpe/gpt2/encoder.json"
)
bpe_vocab_path: str = (
"manifold://pytext_training/tree/static/vocabs/bpe/gpt2/vocab.bpe"
)

@classmethod
def from_config(cls, config: Config):
dictionary = Dictionary.load(config.token_dictionary_path)
bpe = create_gpt2_bpe(config.bpe_encoder_path, config.bpe_vocab_path)
# This hacks the bpe instance to be picklable
bpe = copy.copy(bpe)
bpe.__class__ = PickleableGPT2BPEEncoder

return cls(bpe, dictionary)

def __init__(self, bpe, dictionary: Dictionary):
self.bpe = bpe
self.vocab = Vocabulary(
dictionary.symbols,
pad_token=str(dictionary[dictionary.pad()]),
bos_token=str(dictionary[dictionary.bos()]),
eos_token=str(dictionary[dictionary.eos()]),
)
self.bos = self.vocab.bos_token
self.eos = self.vocab.eos_token

def tokenize(self, input_str: str) -> List[Token]:
bpe_ids = self.bpe.encode(input_str)
char_tokens = [self.bpe.decoder[id].lstrip(u"\u0120") for id in bpe_ids]
lengths = [len(token) for token in char_tokens]
tokens = []
end = 0
for length, id, char_token in zip(lengths, bpe_ids, char_tokens):
start = input_str.find(char_token, end)
end = start + length
tokens.append(Token(str(id), start, end))
return [token for token in tokens if token.value]
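Two details worth noting here. GPT-2's byte-level BPE marks a leading space with Ġ (U+0120), which is why tokenize() strips it before recovering character offsets via str.find, and each Token.value carries the BPE id as a string so the downstream Vocabulary lookup against the fairseq dictionary lines up. The PickleableGPT2BPEEncoder shim is what lets the whole tokenizer go through torch.save. A small sketch, with local paths as assumptions:

```python
import torch

from pytext.data.tokenizers import Gpt2Tokenizer

# Assumption: local copies of the GPT-2 artifacts.
tokenizer = Gpt2Tokenizer.from_config(
    Gpt2Tokenizer.Config(
        token_dictionary_path="gpt2/dict.txt",
        bpe_encoder_path="gpt2/encoder.json",
        bpe_vocab_path="gpt2/vocab.bpe",
    )
)

# Each Token.value is str(bpe_id); start/end are character offsets into the input.
tokens = tokenizer.tokenize("hello world")

# __getstate__ drops the regex module and __setstate__ re-imports it, so the
# tokenizer (and any tensorizer holding it) can be serialized and reloaded.
torch.save(tokenizer, "gpt2_tokenizer.pt")
reloaded = torch.load("gpt2_tokenizer.pt")
assert [t.value for t in reloaded.tokenize("hello world")] == [t.value for t in tokens]
```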
13 changes: 9 additions & 4 deletions pytext/models/qna/bert_squad_qa.py
@@ -1,8 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Union

import torch
from pytext.common.constants import Stage
from pytext.data.squad_for_bert_tensorizer import SquadForBERTTensorizer
from pytext.data.squad_for_bert_tensorizer import (
SquadForBERTTensorizer,
SquadForRoBERTaTensorizer,
)
from pytext.data.tensorizers import LabelTensorizer
from pytext.data.utils import Vocabulary
from pytext.models.bert_classification_models import NewBertModel
@@ -21,9 +26,9 @@
class BertSquadQAModel(NewBertModel):
class Config(NewBertModel.Config):
class ModelInput(BaseModel.Config.ModelInput):
squad_input: SquadForBERTTensorizer.Config = SquadForBERTTensorizer.Config(
max_seq_len=256
)
squad_input: Union[
SquadForBERTTensorizer.Config, SquadForRoBERTaTensorizer.Config
] = SquadForBERTTensorizer.Config(max_seq_len=256)
# is_impossible label
has_answer: LabelTensorizer.Config = LabelTensorizer.Config(
column="has_answer"
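With the Union type above, a model config can swap the RoBERTa tensorizer in for the BERT one. A hedged sketch, assuming the rest of the Config keeps its usual inputs: ModelInput field (not shown in this hunk):

```python
from pytext.data.squad_for_bert_tensorizer import SquadForRoBERTaTensorizer
from pytext.models.qna.bert_squad_qa import BertSquadQAModel

model_config = BertSquadQAModel.Config(
    inputs=BertSquadQAModel.Config.ModelInput(
        squad_input=SquadForRoBERTaTensorizer.Config(max_seq_len=256)
    )
)
```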
43 changes: 43 additions & 0 deletions pytext/models/roberta.py
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pytext.config import ConfigBase
from pytext.data.bert_tensorizer import RoBERTaTensorizer
from pytext.data.tensorizers import LabelTensorizer
from pytext.models.bert_classification_models import NewBertModel
from pytext.models.module import Module, create_module
from pytext.models.representations.transformer_sentence_encoder_base import (
TransformerSentenceEncoderBase,
)


class RoBERTaEncoder(TransformerSentenceEncoderBase):
class Config(TransformerSentenceEncoderBase.Config):
pretrained_encoder: Module.Config = Module.Config(
load_path=(
"manifold://pytext_training/tree/static/models/roberta_public.pt1"
)
)

def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None:
super().__init__(config, output_encoded_layers=output_encoded_layers)
assert config.pretrained_encoder.load_path, "Load path cannot be empty."
self.encoder = create_module(config.pretrained_encoder)
self.representation_dim = self.encoder.encoder.token_embedding.weight.size(-1)

def _encoder(self, inputs):
# NewBertModel expects the output as a tuple and grabs the first element
tokens, _, _ = inputs
full_representation = self.encoder(tokens)
sentence_rep = full_representation[:, 0, :]
return [full_representation], sentence_rep


class RoBERTa(NewBertModel):
class Config(NewBertModel.Config):
class InputConfig(ConfigBase):
tokens: RoBERTaTensorizer.Config = RoBERTaTensorizer.Config()
labels: LabelTensorizer.Config = LabelTensorizer.Config()

inputs: InputConfig = InputConfig()
encoder: RoBERTaEncoder.Config = RoBERTaEncoder.Config()
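To close the loop, a rough sketch of configuring the new classification model end to end. The load_path and the label column are assumptions (the default pretrained_encoder path points at an internal manifold:// bucket). Note that _encoder() takes the first token's representation (RoBERTa's <s>, playing the role of BERT's [CLS]) as the sentence embedding and returns it alongside the full sequence representation, which is the shape NewBertModel expects:

```python
from pytext.data.bert_tensorizer import RoBERTaTensorizer
from pytext.data.tensorizers import LabelTensorizer
from pytext.models.module import Module
from pytext.models.roberta import RoBERTa, RoBERTaEncoder

config = RoBERTa.Config(
    inputs=RoBERTa.Config.InputConfig(
        tokens=RoBERTaTensorizer.Config(columns=["text"]),
        labels=LabelTensorizer.Config(column="label"),  # assumed label column
    ),
    encoder=RoBERTaEncoder.Config(
        # Assumption: a local snapshot of the public RoBERTa weights.
        pretrained_encoder=Module.Config(load_path="roberta_public.pt1")
    ),
)
```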