This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit e4df5c1

Kartikay Khandelwal authored and facebook-github-bot committed
Integrate XLM-R into PyText
Summary: Adding the ability to load and finetune XLM-R models in PyText.

Reviewed By: rutyrinott

Differential Revision: D18382033

fbshipit-source-id: e201b2f50129814950784fc255f4e6bfb4610352
1 parent be627c1 commit e4df5c1

File tree

3 files changed: +63 -2 lines changed

pytext/data/roberta_tensorizer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@ class Config(BERTTensorizerBase.Config):
         vocab_file: str = (
             "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
         )
-        tokenizer: GPT2BPETokenizer.Config = GPT2BPETokenizer.Config()
+        tokenizer: Tokenizer.Config = GPT2BPETokenizer.Config()
         max_seq_len: int = 256

     @classmethod
```
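The only change here is the type annotation: `tokenizer` is now declared as the base `Tokenizer.Config` rather than `GPT2BPETokenizer.Config`, so a config can substitute a different tokenizer (such as the SentencePiece-based one XLM-R uses) while GPT-2 BPE stays the default. A minimal sketch of the idea; `SentencePieceTokenizer` is assumed to exist in PyText and is not part of this diff:

```python
# Sketch only: with the annotation relaxed to the base Tokenizer.Config, any
# tokenizer config subclass can be assigned to the tensorizer's `tokenizer`
# field. SentencePieceTokenizer is an assumption here (not shown in this diff);
# its config would point at the XLM-R SentencePiece model.
from pytext.data.tokenizers import GPT2BPETokenizer, SentencePieceTokenizer, Tokenizer


def pick_tokenizer_config(use_xlmr: bool) -> Tokenizer.Config:
    # RoBERTa keeps the GPT-2 BPE default; XLM-R swaps in a SentencePiece tokenizer.
    return SentencePieceTokenizer.Config() if use_xlmr else GPT2BPETokenizer.Config()
```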

pytext/docs/source/xlm_r.rst

Lines changed: 47 additions & 0 deletions
# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa)

## Introduction

XLM-R (XLM-RoBERTa) is a scaled cross-lingual sentence encoder. It is trained on `2.5T` of data across `100` languages filtered from Common Crawl. XLM-R achieves state-of-the-art results on multiple cross-lingual benchmarks.

## Pre-trained models

Model | Description | #params | vocab size | Download
---|---|---|---|---
`xlmr.base.v0` | XLM-R using the BERT-base architecture | 250M | 250k | [xlm.base.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz)
`xlmr.large.v0` | XLM-R using the BERT-large architecture | 560M | 250k | [xlm.large.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz)

(Note: the above models are still under training; we will update the weights once they are fully trained. The results below are based on the above checkpoints.)

## Results

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**

Model | average | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` _(TRANSLATE-TEST)_ | 77.8 | 91.3 | 82.9 | 84.3 | 81.2 | 81.7 | 83.1 | 78.3 | 76.8 | 76.6 | 74.2 | 74.1 | 77.5 | 70.9 | 66.7 | 66.8
`xlmr.large.v0` _(TRANSLATE-TRAIN-ALL)_ | **82.4** | 88.7 | 85.2 | 85.6 | 84.6 | 83.6 | 85.5 | 82.4 | 81.6 | 80.9 | 83.4 | 80.9 | 83.3 | 79.8 | 75.9 | 74.3

**[MLQA (Lewis et al., 2018)](https://arxiv.org/abs/1910.07475)**

Model | average | en | es | de | ar | hi | vi | zh
---|---|---|---|---|---|---|---|---
`BERT-large` | - | 80.2 / 67.4 | - | - | - | - | - | -
`mBERT` | 57.7 / 41.6 | 77.7 / 65.2 | 64.3 / 46.6 | 57.9 / 44.3 | 45.7 / 29.8 | 43.8 / 29.7 | 57.1 / 38.6 | 57.5 / 37.3
`xlmr.large.v0` | **70.0 / 52.2** | 80.1 / 67.7 | 73.2 / 55.1 | 68.3 / 53.7 | 62.8 / 43.7 | 68.3 / 51.0 | 70.5 / 50.1 | 67.1 / 44.4

## Citation

```bibtex
@article{,
    title = {Unsupervised Cross-lingual Representation Learning at Scale},
    author = {Alexis Conneau and Kartikay Khandelwal
        and Naman Goyal and Vishrav Chaudhary and Guillaume Wenzek
        and Francisco Guzm\'an and Edouard Grave and Myle Ott
        and Luke Zettlemoyer and Veselin Stoyanov
    },
    journal={},
    year = {2019},
}
```
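As a usage note (not part of the committed file): the download links in the table above point to plain tar archives, so a checkpoint can be fetched and unpacked with the Python standard library. A minimal sketch, assuming the `xlmr.base.v0` URL from the table and arbitrary local paths:

```python
# Sketch only: download and unpack the xlmr.base.v0 archive listed above.
# The URL comes from the pre-trained models table; local paths are arbitrary.
import tarfile
import urllib.request

url = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz"
archive_path = "xlmr.base.v0.tar.gz"

urllib.request.urlretrieve(url, archive_path)  # download the tarball
with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall("xlmr.base.v0")  # unpack the checkpoint and vocabulary files
```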

pytext/models/roberta.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -57,17 +57,25 @@ class RoBERTaEncoder(RoBERTaEncoderBase):
     """A PyTorch RoBERTa implementation"""

     class Config(RoBERTaEncoderBase.Config):
+        embedding_dim: int = 768
+        vocab_size: int = 50265
         num_encoder_layers: int = 12
         num_attention_heads: int = 12
         model_path: str = (
             "manifold://pytext_training/tree/static/models/roberta_base_torch.pt"
         )
+        # Loading the state dict of the model depends on whether the model was
+        # previously finetuned in PyText or not. If it was finetuned, then we
+        # don't need to translate the state dict and can just load it
+        # directly.
+        is_finetuned: bool = False

     def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None:
         super().__init__(config, output_encoded_layers=output_encoded_layers)
         # assert config.pretrained_encoder.load_path, "Load path cannot be empty."
         self.encoder = SentenceEncoder(
             transformer=Transformer(
+                vocab_size=config.vocab_size,
                 embedding_dim=config.embedding_dim,
                 layers=[
                     TransformerLayer(
@@ -84,7 +92,13 @@ def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None
             config.model_path,
             map_location=lambda s, l: default_restore_location(s, "cpu"),
         )
-        self.encoder.load_roberta_state_dict(roberta_state["model"])
+        # If the model has previously been loaded in PyText and finetuned,
+        # we don't need to do the special state dict translation. Load
+        # it directly.
+        if not config.is_finetuned:
+            self.encoder.load_roberta_state_dict(roberta_state["model"])
+        else:
+            self.load_state_dict(roberta_state)
         self.representation_dim = self.encoder.transformer.token_embedding.weight.size(
             -1
         )
```
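Taken together, the new `Config` fields make the encoder shape configurable rather than hard-coded, which is what XLM-R needs (its SentencePiece vocabulary is roughly 250k entries versus RoBERTa's 50,265 GPT-2 BPE tokens), and `is_finetuned` decides whether the checkpoint goes through the fairseq-style state-dict translation or is loaded directly. A rough sketch of how an XLM-R base encoder config might be built; the exact vocabulary size and the checkpoint path are illustrative assumptions, not values from this commit:

```python
# Sketch only: configuring RoBERTaEncoder for an XLM-R base checkpoint.
# The field names come from the Config in the diff above; the values and the
# path are illustrative assumptions (XLM-R base: ~250k SentencePiece vocab,
# 768-dim embeddings, 12 layers, 12 heads).
from pytext.models.roberta import RoBERTaEncoder

xlmr_base_config = RoBERTaEncoder.Config(
    vocab_size=250002,          # assumed XLM-R vocab size; RoBERTa default is 50265
    embedding_dim=768,
    num_encoder_layers=12,
    num_attention_heads=12,
    model_path="/path/to/xlmr_base_torch.pt",  # hypothetical converted checkpoint
    # False: translate the fairseq-style state dict via load_roberta_state_dict;
    # True: the checkpoint was already finetuned in PyText, so load it directly.
    is_finetuned=False,
)

# Building the encoder loads the checkpoint at model_path, so this line only
# runs if a real file exists at that location.
encoder = RoBERTaEncoder(xlmr_base_config, output_encoded_layers=True)
```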
