@@ -53,7 +53,7 @@ def build_argparse():
     parser.add_argument('--hidden_dim', type=int, default=100)
     parser.add_argument('--emb_dim', type=int, default=50)
-    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
     parser.add_argument('--emb_dropout', type=float, default=0.5)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--max_dec_len', type=int, default=50)
@@ -153,6 +153,12 @@ def train(args):
     args['vocab_size'] = vocab.size
     dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
 
+    if args['num_layers'] is None:
+        if args['force_exact_pieces']:
+            args['num_layers'] = 2
+        else:
+            args['num_layers'] = 1
+
     # train a dictionary-based MWT expander
     trainer = Trainer(args=args, vocab=vocab, device=args['device'])
     logger.info("Training dictionary-based MWT expander...")
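The diff uses a common argparse idiom: a sentinel default of `None` lets the code tell whether the user actually passed `--num_layers`, and the real default is resolved later once the model variant is known. A minimal, self-contained sketch of the same idiom follows; the flag names mirror the diff, and the reading that `force_exact_pieces` selects the two-layer classifier is an assumption based only on the new help text.

```python
import argparse

def build_argparse():
    parser = argparse.ArgumentParser()
    # default=None distinguishes "left unset" from "explicitly set to 1"
    parser.add_argument('--num_layers', type=int, default=None,
                        help='Number of layers in model encoder. '
                             'Defaults to 1 for seq2seq, 2 for classifier')
    parser.add_argument('--force_exact_pieces', action='store_true')
    return parser

def resolve_defaults(args):
    # Resolve the deferred default once the model variant is known
    # (assumption: force_exact_pieces selects the two-layer classifier)
    if args['num_layers'] is None:
        args['num_layers'] = 2 if args['force_exact_pieces'] else 1
    return args

args = vars(build_argparse().parse_args(['--force_exact_pieces']))
print(resolve_defaults(args)['num_layers'])  # -> 2
args = vars(build_argparse().parse_args([]))
print(resolve_defaults(args)['num_layers'])  # -> 1
```

Resolving after parsing keeps an explicit `--num_layers 1` distinguishable from the classifier's implicit default of 2.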