@@ -57,17 +57,25 @@ class RoBERTaEncoder(RoBERTaEncoderBase):
     """A PyTorch RoBERTa implementation"""
 
     class Config(RoBERTaEncoderBase.Config):
+        embedding_dim: int = 768
+        vocab_size: int = 50265
         num_encoder_layers: int = 12
         num_attention_heads: int = 12
         model_path: str = (
             "manifold://pytext_training/tree/static/models/roberta_base_torch.pt"
         )
+        # Loading the state dict of the model depends on whether the model was
+        # previously finetuned in PyText or not. If it was finetuned then we
+        # don't need to translate the state dict and can just load it
+        # directly.
+        is_finetuned: bool = False
 
     def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None:
         super().__init__(config, output_encoded_layers=output_encoded_layers)
         # assert config.pretrained_encoder.load_path, "Load path cannot be empty."
         self.encoder = SentenceEncoder(
             transformer=Transformer(
+                vocab_size=config.vocab_size,
                 embedding_dim=config.embedding_dim,
                 layers=[
                     TransformerLayer(
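
The two new Config fields feed straight into the Transformer constructed here: `vocab_size` sets the size of the token embedding table and `embedding_dim` its width, which is also what the second hunk reads back as `representation_dim`. A self-contained sketch of that shape relationship (plain `nn.Embedding` only; the real `Transformer` also takes a padding index and the layer stack shown above):

```python
from torch import nn

# Sketch: what the new Config fields control. The token embedding table is
# (vocab_size x embedding_dim), and representation_dim is recovered from the
# last dimension of its weight, i.e. it equals embedding_dim.
vocab_size, embedding_dim = 50265, 768  # defaults added in this diff
token_embedding = nn.Embedding(vocab_size, embedding_dim)

representation_dim = token_embedding.weight.size(-1)
assert representation_dim == embedding_dim
assert token_embedding.weight.shape == (vocab_size, embedding_dim)
```
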
@@ -84,7 +92,13 @@ def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None
             config.model_path,
             map_location=lambda s, l: default_restore_location(s, "cpu"),
         )
-        self.encoder.load_roberta_state_dict(roberta_state["model"])
+        # In case the model has previously been loaded in PyText and finetuned,
+        # then we don't need to do the special state dict translation. Load
+        # it directly.
+        if not config.is_finetuned:
+            self.encoder.load_roberta_state_dict(roberta_state["model"])
+        else:
+            self.load_state_dict(roberta_state)
         self.representation_dim = self.encoder.transformer.token_embedding.weight.size(
             -1
         )
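
For context, the branch above distinguishes two checkpoint layouts: a pretrained fairseq-style RoBERTa checkpoint keeps its tensors under `roberta_state["model"]` with key names that `load_roberta_state_dict` translates into this module's layout, while a checkpoint saved after finetuning in PyText already matches and is loaded directly with `load_state_dict`. A toy, self-contained sketch of that pattern (module and key names are illustrative, not PyText's actual ones):

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Toy stand-in for the encoder; only the loading pattern matters here."""

    def __init__(self) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(100, 8)

    def load_translated_state_dict(self, foreign_state):
        # Analogue of load_roberta_state_dict: remap the checkpoint's key
        # names onto this module's attribute names before loading.
        translated = {
            key.replace("decoder.embed_tokens", "token_embedding"): value
            for key, value in foreign_state.items()
        }
        self.load_state_dict(translated)


encoder = TinyEncoder()

# is_finetuned=False analogue: keys use the original layout, translate first.
foreign = {"decoder.embed_tokens.weight": torch.zeros(100, 8)}
encoder.load_translated_state_dict(foreign)

# is_finetuned=True analogue: keys already match this module, load directly.
encoder.load_state_dict(encoder.state_dict())
```
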