Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 16e4f6f

Browse files
Xiaojian Wufacebook-github-bot
Xiaojian Wu
authored andcommitted
Add Roberta model into BertPairwiseModel (#1336)
Summary: Pull Request resolved: #1336 BertPairwiseModel only supports BERTTensorizer so we can't use Roberta as encoder. This change enables RobertaTensorizer. Reviewed By: liaimi Differential Revision: D21054324 fbshipit-source-id: ad792e5d80ada73be29d2feacaec4c3bfdec2f82
1 parent 942a49f commit 16e4f6f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytext/models/bert_classification_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from pytext.common.constants import Stage
99
from pytext.config.component import create_loss
10-
from pytext.data.bert_tensorizer import BERTTensorizer
10+
from pytext.data.bert_tensorizer import BERTTensorizer, BERTTensorizerBase
1111
from pytext.data.dense_retrieval_tensorizer import ( # noqa
1212
BERTContextTensorizerForDenseRetrieval,
1313
PositiveLabelTensorizerForDenseRetrieval,
@@ -139,10 +139,10 @@ class BertPairwiseModel(BasePairwiseModel):
139139

140140
class Config(BasePairwiseModel.Config):
141141
class ModelInput(ModelInputBase):
142-
tokens1: BERTTensorizer.Config = BERTTensorizer.Config(
142+
tokens1: BERTTensorizerBase.Config = BERTTensorizer.Config(
143143
columns=["text1"], max_seq_len=128
144144
)
145-
tokens2: BERTTensorizer.Config = BERTTensorizer.Config(
145+
tokens2: BERTTensorizerBase.Config = BERTTensorizer.Config(
146146
columns=["text2"], max_seq_len=128
147147
)
148148
labels: LabelTensorizer.Config = LabelTensorizer.Config()

0 commit comments

Comments
 (0)