
Commit 9766815

Merge branch 'master' into fop

2 parents b0cf3ca + f205ffd

File tree

5 files changed: +532 −2 lines changed


example/ctc/README.md

+116
@@ -0,0 +1,116 @@
# CTC with MXNet

This is an `mx.contrib.sym.ctc_loss` example. It was modified from the [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc) example.

# Core code

This is the core change in `lstm.py`:
```python
def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    # one set of weight/bias/state variables per LSTM layer
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # slice the input into one feature vector per time step
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    # unroll the stacked LSTM over the sequence
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # (seq_len * batch, num_hidden): time-major blocks stacked along axis 0
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)

    # project to 11 classes (10 digits + 1 CTC blank), then reshape to
    # (seq_len, batch, 11) as mx.contrib.sym.ctc_loss expects
    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

    loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # also expose the per-step softmax for prediction, with gradients blocked
    softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)

    return mx.sym.Group([softmax_loss, ctc_loss])
```
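
For a quick sanity check of the wiring, the unrolled symbol's output shapes can be inspected with shape inference. This sketch is not part of the commit; the batch size, the 30-wide per-step feature vector, and the 4-digit label length are illustrative assumptions.

```python
from lstm import lstm_unroll  # the file shown below

batch_size, seq_len, num_hidden = 32, 80, 100  # illustrative sizes
sym = lstm_unroll(num_lstm_layer=2, seq_len=seq_len,
                  num_hidden=num_hidden, num_label=4)

# shapes for the data, the label, and each layer's initial states
shapes = {'data': (batch_size, seq_len, 30),  # 30 features per step (assumption)
          'label': (batch_size, 4)}
for i in range(2):
    shapes['l%d_init_c' % i] = (batch_size, num_hidden)
    shapes['l%d_init_h' % i] = (batch_size, num_hidden)

_, out_shapes, _ = sym.infer_shape(**shapes)
print(out_shapes)  # [(2560, 11), (32,)]: per-step softmax and per-sample CTC loss
```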
# Some Results

With more training, the results would improve further:
```
2017-07-08 13:22:01,155 Epoch[94] Batch [50] Speed: 4273.43 samples/sec Accuracy=0.808747
2017-07-08 13:22:13,141 Epoch[94] Batch [100] Speed: 4271.84 samples/sec Accuracy=0.786855
2017-07-08 13:22:25,179 Epoch[94] Batch [150] Speed: 4253.81 samples/sec Accuracy=0.810625
2017-07-08 13:22:37,198 Epoch[94] Batch [200] Speed: 4259.96 samples/sec Accuracy=0.808809
2017-07-08 13:22:49,233 Epoch[94] Batch [250] Speed: 4254.13 samples/sec Accuracy=0.806426
2017-07-08 13:23:01,308 Epoch[94] Batch [300] Speed: 4239.98 samples/sec Accuracy=0.817305
2017-07-08 13:23:02,030 Epoch[94] Train-Accuracy=0.819336
2017-07-08 13:23:02,030 Epoch[94] Time cost=73.092
2017-07-08 13:23:02,101 Saved checkpoint to "ocr-0095.params"
2017-07-08 13:23:07,192 Epoch[94] Validation-Accuracy=0.819417
2017-07-08 13:23:20,579 Epoch[95] Batch [50] Speed: 4288.76 samples/sec Accuracy=0.817459
2017-07-08 13:23:32,573 Epoch[95] Batch [100] Speed: 4268.75 samples/sec Accuracy=0.815215
2017-07-08 13:23:44,635 Epoch[95] Batch [150] Speed: 4244.85 samples/sec Accuracy=0.820215
2017-07-08 13:23:56,670 Epoch[95] Batch [200] Speed: 4254.38 samples/sec Accuracy=0.823613
2017-07-08 13:24:08,650 Epoch[95] Batch [250] Speed: 4273.83 samples/sec Accuracy=0.827109
2017-07-08 13:24:20,680 Epoch[95] Batch [300] Speed: 4256.49 samples/sec Accuracy=0.824961
2017-07-08 13:24:21,401 Epoch[95] Train-Accuracy=0.840495
2017-07-08 13:24:21,401 Epoch[95] Time cost=73.008
2017-07-08 13:24:21,441 Saved checkpoint to "ocr-0096.params"
2017-07-08 13:24:26,508 Epoch[95] Validation-Accuracy=0.834798
2017-07-08 13:24:39,938 Epoch[96] Batch [50] Speed: 4259.32 samples/sec Accuracy=0.825578
2017-07-08 13:24:51,987 Epoch[96] Batch [100] Speed: 4249.67 samples/sec Accuracy=0.826562
2017-07-08 13:25:04,041 Epoch[96] Batch [150] Speed: 4247.44 samples/sec Accuracy=0.831855
2017-07-08 13:25:16,058 Epoch[96] Batch [200] Speed: 4260.77 samples/sec Accuracy=0.830840
2017-07-08 13:25:28,109 Epoch[96] Batch [250] Speed: 4248.44 samples/sec Accuracy=0.827168
2017-07-08 13:25:40,057 Epoch[96] Batch [300] Speed: 4285.23 samples/sec Accuracy=0.832715
2017-07-08 13:25:40,782 Epoch[96] Train-Accuracy=0.830729
2017-07-08 13:25:40,782 Epoch[96] Time cost=73.098
2017-07-08 13:25:40,821 Saved checkpoint to "ocr-0097.params"
2017-07-08 13:25:45,886 Epoch[96] Validation-Accuracy=0.840820
2017-07-08 13:25:59,283 Epoch[97] Batch [50] Speed: 4271.85 samples/sec Accuracy=0.831648
2017-07-08 13:26:11,243 Epoch[97] Batch [100] Speed: 4280.89 samples/sec Accuracy=0.835371
2017-07-08 13:26:23,263 Epoch[97] Batch [150] Speed: 4259.89 samples/sec Accuracy=0.831094
2017-07-08 13:26:35,230 Epoch[97] Batch [200] Speed: 4278.40 samples/sec Accuracy=0.827129
2017-07-08 13:26:47,199 Epoch[97] Batch [250] Speed: 4277.77 samples/sec Accuracy=0.834258
2017-07-08 13:26:59,257 Epoch[97] Batch [300] Speed: 4245.93 samples/sec Accuracy=0.833770
2017-07-08 13:26:59,971 Epoch[97] Train-Accuracy=0.844727
2017-07-08 13:26:59,971 Epoch[97] Time cost=72.908
2017-07-08 13:27:00,020 Saved checkpoint to "ocr-0098.params"
2017-07-08 13:27:05,130 Epoch[97] Validation-Accuracy=0.827962
2017-07-08 13:27:18,521 Epoch[98] Batch [50] Speed: 4281.06 samples/sec Accuracy=0.834118
2017-07-08 13:27:30,537 Epoch[98] Batch [100] Speed: 4261.20 samples/sec Accuracy=0.835352
2017-07-08 13:27:42,542 Epoch[98] Batch [150] Speed: 4264.88 samples/sec Accuracy=0.839395
2017-07-08 13:27:54,544 Epoch[98] Batch [200] Speed: 4266.31 samples/sec Accuracy=0.836328
2017-07-08 13:28:06,550 Epoch[98] Batch [250] Speed: 4264.50 samples/sec Accuracy=0.841465
2017-07-08 13:28:18,622 Epoch[98] Batch [300] Speed: 4241.11 samples/sec Accuracy=0.831680
2017-07-08 13:28:19,349 Epoch[98] Train-Accuracy=0.833984
2017-07-08 13:28:19,349 Epoch[98] Time cost=73.018
2017-07-08 13:28:19,393 Saved checkpoint to "ocr-0099.params"
2017-07-08 13:28:24,472 Epoch[98] Validation-Accuracy=0.818034
2017-07-08 13:28:37,961 Epoch[99] Batch [50] Speed: 4242.14 samples/sec Accuracy=0.835861
2017-07-08 13:28:50,031 Epoch[99] Batch [100] Speed: 4241.94 samples/sec Accuracy=0.846543
2017-07-08 13:29:02,108 Epoch[99] Batch [150] Speed: 4239.22 samples/sec Accuracy=0.850645
2017-07-08 13:29:14,160 Epoch[99] Batch [200] Speed: 4248.34 samples/sec Accuracy=0.844141
2017-07-08 13:29:26,225 Epoch[99] Batch [250] Speed: 4243.71 samples/sec Accuracy=0.842129
2017-07-08 13:29:38,277 Epoch[99] Batch [300] Speed: 4248.07 samples/sec Accuracy=0.851250
2017-07-08 13:29:38,975 Epoch[99] Train-Accuracy=0.854492
2017-07-08 13:29:38,976 Epoch[99] Time cost=73.315
2017-07-08 13:29:39,023 Saved checkpoint to "ocr-0100.params"
2017-07-08 13:29:44,110 Epoch[99] Validation-Accuracy=0.851969
```
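
The checkpoints named in the log can be reloaded later for prediction; a minimal sketch (not part of the commit), assuming the same `ocr` prefix used above:

```python
import mxnet as mx

# load symbol and parameters from the last checkpoint the log reports
sym, arg_params, aux_params = mx.model.load_checkpoint('ocr', 100)
```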

example/ctc/lstm.py

+89
@@ -0,0 +1,89 @@
# pylint:skip-file
import sys

# make the local mxnet tree importable before importing mxnet
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol"""
    # fused input-to-hidden and hidden-to-hidden projections for all four gates
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    # one set of weight/bias/state variables per LSTM layer
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # slice the input into one feature vector per time step
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    # unroll the stacked LSTM over the sequence
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # (seq_len * batch, num_hidden): time-major blocks stacked along axis 0
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)

    # project to 11 classes (10 digits + 1 CTC blank), then reshape to
    # (seq_len, batch, 11) as mx.contrib.sym.ctc_loss expects
    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

    loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # also expose the per-step softmax for prediction, with gradients blocked
    softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)

    return mx.sym.Group([softmax_loss, ctc_loss])