
Commit ce61992

Kartikay Khandelwal authored and facebook-github-bot committed
Integrate XLM-R into PyText (#1120)
Summary:
Pull Request resolved: #1120

Adding the ability to load and finetune XLM-R models in PyText.

Reviewed By: rutyrinott

Differential Revision: D18382033

fbshipit-source-id: 157a53fb44b46452fed7005db9682c9dc46f28da
1 parent 901587d commit ce61992

2 files changed: +16 −2 lines changed


pytext/data/roberta_tensorizer.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class Config(BERTTensorizerBase.Config):
         vocab_file: str = (
             "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
         )
-        tokenizer: GPT2BPETokenizer.Config = GPT2BPETokenizer.Config()
+        tokenizer: Tokenizer.Config = GPT2BPETokenizer.Config()
         max_seq_len: int = 256
 
     @classmethod
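This change widens the tokenizer field's annotation from GPT2BPETokenizer.Config to the base Tokenizer.Config, so the tensorizer config can accept any tokenizer's config (for example the SentencePiece-style tokenizer an XLM-R model needs) while GPT-2 BPE remains the default. A self-contained sketch of that typing pattern, using plain dataclasses rather than PyText's actual config classes; every class name and path below is illustrative:

from dataclasses import dataclass, field


@dataclass
class TokenizerConfig:  # stands in for Tokenizer.Config
    lowercase: bool = False


@dataclass
class GPT2BPETokenizerConfig(TokenizerConfig):  # stands in for GPT2BPETokenizer.Config
    bpe_vocab_path: str = "gpt2/vocab.bpe"  # hypothetical default


@dataclass
class SentencePieceTokenizerConfig(TokenizerConfig):  # roughly what XLM-R tokenization needs
    sp_model_path: str = "sentencepiece.bpe.model"  # hypothetical default


@dataclass
class RoBERTaTensorizerConfig:
    # Widened annotation: any TokenizerConfig subclass fits; default stays GPT-2 BPE.
    tokenizer: TokenizerConfig = field(default_factory=GPT2BPETokenizerConfig)
    max_seq_len: int = 256


# Under the old, narrower GPT2BPETokenizer.Config annotation this override would not validate:
xlmr_tensorizer = RoBERTaTensorizerConfig(tokenizer=SentencePieceTokenizerConfig())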

pytext/models/roberta.py

Lines changed: 15 additions & 1 deletion
@@ -57,17 +57,25 @@ class RoBERTaEncoder(RoBERTaEncoderBase):
     """A PyTorch RoBERTa implementation"""
 
     class Config(RoBERTaEncoderBase.Config):
+        embedding_dim: int = 768
+        vocab_size: int = 50265
         num_encoder_layers: int = 12
         num_attention_heads: int = 12
         model_path: str = (
             "manifold://pytext_training/tree/static/models/roberta_base_torch.pt"
         )
+        # Loading the state dict of the model depends on whether the model was
+        # previously finetuned in PyText or not. If it was finetuned, we don't
+        # need to translate the state dict and can just load it
+        # directly.
+        is_finetuned: bool = False
 
     def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None:
         super().__init__(config, output_encoded_layers=output_encoded_layers)
         # assert config.pretrained_encoder.load_path, "Load path cannot be empty."
         self.encoder = SentenceEncoder(
             transformer=Transformer(
+                vocab_size=config.vocab_size,
                 embedding_dim=config.embedding_dim,
                 layers=[
                     TransformerLayer(
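The new Config fields make the vocabulary and embedding sizes configurable and pass vocab_size through to the Transformer, since an XLM-R checkpoint uses a larger SentencePiece vocabulary than RoBERTa's 50265-entry GPT-2 BPE vocabulary. A minimal sketch in plain PyTorch (not PyText's Transformer class) of why the token embedding depends on both values:

import torch.nn as nn


def build_token_embedding(vocab_size: int, embedding_dim: int) -> nn.Embedding:
    # The embedding table is vocab_size x embedding_dim, so a checkpoint's
    # embedding weights only load if vocab_size matches its vocabulary.
    return nn.Embedding(vocab_size, embedding_dim)


# RoBERTa-base defaults taken from the Config above:
roberta_embedding = build_token_embedding(vocab_size=50265, embedding_dim=768)

# An XLM-R model would override vocab_size in its config to match its own
# SentencePiece vocabulary (the exact size is not given in this commit).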
@@ -84,7 +92,13 @@ def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None
             config.model_path,
             map_location=lambda s, l: default_restore_location(s, "cpu"),
         )
-        self.encoder.load_roberta_state_dict(roberta_state["model"])
+        # If the model has previously been loaded in PyText and finetuned,
+        # we don't need to do the special state dict translation and can
+        # load it directly.
+        if not config.is_finetuned:
+            self.encoder.load_roberta_state_dict(roberta_state["model"])
+        else:
+            self.load_state_dict(roberta_state)
         self.representation_dim = self.encoder.transformer.token_embedding.weight.size(
             -1
         )
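The is_finetuned flag selects between two loading paths: a checkpoint straight from pretraining, whose parameter names load_roberta_state_dict has to translate, and a checkpoint already saved from this PyText module after finetuning, which can be loaded directly. A simplified sketch of that branching in plain PyTorch; the translate_keys helper and the checkpoint layout are assumptions for illustration, not PyText's actual format:

import torch
import torch.nn as nn


def translate_keys(pretrained_state: dict) -> dict:
    # Hypothetical stand-in for the key translation that load_roberta_state_dict
    # performs: map the pretraining framework's parameter names onto this
    # encoder's parameter names.
    return {
        key.replace("decoder.sentence_encoder.", ""): value
        for key, value in pretrained_state.items()
    }


class TinyEncoder(nn.Module):
    def __init__(self, vocab_size: int = 50265, embedding_dim: int = 768) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)


def load_checkpoint(model: nn.Module, path: str, is_finetuned: bool) -> None:
    state = torch.load(path, map_location="cpu")
    if not is_finetuned:
        # Checkpoint straight from pretraining: names follow the source
        # framework's layout and need translation before loading.
        model.load_state_dict(translate_keys(state["model"]), strict=False)
    else:
        # Checkpoint saved from this model after finetuning in PyText:
        # names already match, so load it as-is.
        model.load_state_dict(state)

With is_finetuned=False the encoder goes through the translation path; with is_finetuned=True it expects a state dict saved from the PyText module itself.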
