Skip to content

Commit b999102

Browse files
Committed with message: "Set a default for num_layers based on the model type - num_layers=2 seems good for the character classifier"
1 parent 7eb3a50 commit b999102

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

stanza/models/mwt_expander.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def build_argparse():
5353

5454
parser.add_argument('--hidden_dim', type=int, default=100)
5555
parser.add_argument('--emb_dim', type=int, default=50)
56-
parser.add_argument('--num_layers', type=int, default=1)
56+
parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
5757
parser.add_argument('--emb_dropout', type=float, default=0.5)
5858
parser.add_argument('--dropout', type=float, default=0.5)
5959
parser.add_argument('--max_dec_len', type=int, default=50)
@@ -153,6 +153,12 @@ def train(args):
153153
args['vocab_size'] = vocab.size
154154
dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
155155

156+
if args['num_layers'] is None:
157+
if args['force_exact_pieces']:
158+
args['num_layers'] = 2
159+
else:
160+
args['num_layers'] = 1
161+
156162
# train a dictionary-based MWT expander
157163
trainer = Trainer(args=args, vocab=vocab, device=args['device'])
158164
logger.info("Training dictionary-based MWT expander...")

Commit comments: 0