|
| 1 | +# CTC with Mxnet |
| 2 | +this is mx.contrib.sym.ctc_loss example. It was modified from example [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc) |
| 3 | + |
| 4 | +# Core code |
| 5 | +this is core change in lstm.py |
| 6 | +```Cython |
| 7 | +def lstm_unroll(num_lstm_layer, seq_len, |
| 8 | + num_hidden, num_label): |
| 9 | + param_cells = [] |
| 10 | + last_states = [] |
| 11 | + for i in range(num_lstm_layer): |
| 12 | + param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), |
| 13 | + i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), |
| 14 | + h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), |
| 15 | + h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) |
| 16 | + state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), |
| 17 | + h=mx.sym.Variable("l%d_init_h" % i)) |
| 18 | + last_states.append(state) |
| 19 | + assert (len(last_states) == num_lstm_layer) |
| 20 | +
|
| 21 | + # embeding layer |
| 22 | + data = mx.sym.Variable('data') |
| 23 | + label = mx.sym.Variable('label') |
| 24 | + wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) |
| 25 | +
|
| 26 | + hidden_all = [] |
| 27 | + for seqidx in range(seq_len): |
| 28 | + hidden = wordvec[seqidx] |
| 29 | + for i in range(num_lstm_layer): |
| 30 | + next_state = lstm(num_hidden, indata=hidden, |
| 31 | + prev_state=last_states[i], |
| 32 | + param=param_cells[i], |
| 33 | + seqidx=seqidx, layeridx=i) |
| 34 | + hidden = next_state.h |
| 35 | + last_states[i] = next_state |
| 36 | + hidden_all.append(hidden) |
| 37 | +
|
| 38 | + hidden_concat = mx.sym.Concat(*hidden_all, dim=0) |
| 39 | +
|
| 40 | + pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) |
| 41 | + pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0)) |
| 42 | +
|
| 43 | + loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label) |
| 44 | + ctc_loss = mx.sym.MakeLoss(loss) |
| 45 | +
|
| 46 | + softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc) |
| 47 | + softmax_loss = mx.sym.MakeLoss(softmax_class) |
| 48 | + softmax_loss = mx.sym.BlockGrad(softmax_loss) |
| 49 | +
|
| 50 | + return mx.sym.Group([softmax_loss, ctc_loss]) |
| 51 | +``` |
| 52 | +# Some Result |
| 53 | +If there were more training, the result would be better |
| 54 | + |
| 55 | +``` |
| 56 | +2017-07-08 13:22:01,155 Epoch[94] Batch [50] Speed: 4273.43 samples/sec Accuracy=0.808747 |
| 57 | +2017-07-08 13:22:13,141 Epoch[94] Batch [100] Speed: 4271.84 samples/sec Accuracy=0.786855 |
| 58 | +2017-07-08 13:22:25,179 Epoch[94] Batch [150] Speed: 4253.81 samples/sec Accuracy=0.810625 |
| 59 | +2017-07-08 13:22:37,198 Epoch[94] Batch [200] Speed: 4259.96 samples/sec Accuracy=0.808809 |
| 60 | +2017-07-08 13:22:49,233 Epoch[94] Batch [250] Speed: 4254.13 samples/sec Accuracy=0.806426 |
| 61 | +2017-07-08 13:23:01,308 Epoch[94] Batch [300] Speed: 4239.98 samples/sec Accuracy=0.817305 |
| 62 | +2017-07-08 13:23:02,030 Epoch[94] Train-Accuracy=0.819336 |
| 63 | +2017-07-08 13:23:02,030 Epoch[94] Time cost=73.092 |
| 64 | +2017-07-08 13:23:02,101 Saved checkpoint to "ocr-0095.params" |
| 65 | +2017-07-08 13:23:07,192 Epoch[94] Validation-Accuracy=0.819417 |
| 66 | +2017-07-08 13:23:20,579 Epoch[95] Batch [50] Speed: 4288.76 samples/sec Accuracy=0.817459 |
| 67 | +2017-07-08 13:23:32,573 Epoch[95] Batch [100] Speed: 4268.75 samples/sec Accuracy=0.815215 |
| 68 | +2017-07-08 13:23:44,635 Epoch[95] Batch [150] Speed: 4244.85 samples/sec Accuracy=0.820215 |
| 69 | +2017-07-08 13:23:56,670 Epoch[95] Batch [200] Speed: 4254.38 samples/sec Accuracy=0.823613 |
| 70 | +2017-07-08 13:24:08,650 Epoch[95] Batch [250] Speed: 4273.83 samples/sec Accuracy=0.827109 |
| 71 | +2017-07-08 13:24:20,680 Epoch[95] Batch [300] Speed: 4256.49 samples/sec Accuracy=0.824961 |
| 72 | +2017-07-08 13:24:21,401 Epoch[95] Train-Accuracy=0.840495 |
| 73 | +2017-07-08 13:24:21,401 Epoch[95] Time cost=73.008 |
| 74 | +2017-07-08 13:24:21,441 Saved checkpoint to "ocr-0096.params" |
| 75 | +2017-07-08 13:24:26,508 Epoch[95] Validation-Accuracy=0.834798 |
| 76 | +2017-07-08 13:24:39,938 Epoch[96] Batch [50] Speed: 4259.32 samples/sec Accuracy=0.825578 |
| 77 | +2017-07-08 13:24:51,987 Epoch[96] Batch [100] Speed: 4249.67 samples/sec Accuracy=0.826562 |
| 78 | +2017-07-08 13:25:04,041 Epoch[96] Batch [150] Speed: 4247.44 samples/sec Accuracy=0.831855 |
| 79 | +2017-07-08 13:25:16,058 Epoch[96] Batch [200] Speed: 4260.77 samples/sec Accuracy=0.830840 |
| 80 | +2017-07-08 13:25:28,109 Epoch[96] Batch [250] Speed: 4248.44 samples/sec Accuracy=0.827168 |
| 81 | +2017-07-08 13:25:40,057 Epoch[96] Batch [300] Speed: 4285.23 samples/sec Accuracy=0.832715 |
| 82 | +2017-07-08 13:25:40,782 Epoch[96] Train-Accuracy=0.830729 |
| 83 | +2017-07-08 13:25:40,782 Epoch[96] Time cost=73.098 |
| 84 | +2017-07-08 13:25:40,821 Saved checkpoint to "ocr-0097.params" |
| 85 | +2017-07-08 13:25:45,886 Epoch[96] Validation-Accuracy=0.840820 |
| 86 | +2017-07-08 13:25:59,283 Epoch[97] Batch [50] Speed: 4271.85 samples/sec Accuracy=0.831648 |
| 87 | +2017-07-08 13:26:11,243 Epoch[97] Batch [100] Speed: 4280.89 samples/sec Accuracy=0.835371 |
| 88 | +2017-07-08 13:26:23,263 Epoch[97] Batch [150] Speed: 4259.89 samples/sec Accuracy=0.831094 |
| 89 | +2017-07-08 13:26:35,230 Epoch[97] Batch [200] Speed: 4278.40 samples/sec Accuracy=0.827129 |
| 90 | +2017-07-08 13:26:47,199 Epoch[97] Batch [250] Speed: 4277.77 samples/sec Accuracy=0.834258 |
| 91 | +2017-07-08 13:26:59,257 Epoch[97] Batch [300] Speed: 4245.93 samples/sec Accuracy=0.833770 |
| 92 | +2017-07-08 13:26:59,971 Epoch[97] Train-Accuracy=0.844727 |
| 93 | +2017-07-08 13:26:59,971 Epoch[97] Time cost=72.908 |
| 94 | +2017-07-08 13:27:00,020 Saved checkpoint to "ocr-0098.params" |
| 95 | +2017-07-08 13:27:05,130 Epoch[97] Validation-Accuracy=0.827962 |
| 96 | +2017-07-08 13:27:18,521 Epoch[98] Batch [50] Speed: 4281.06 samples/sec Accuracy=0.834118 |
| 97 | +2017-07-08 13:27:30,537 Epoch[98] Batch [100] Speed: 4261.20 samples/sec Accuracy=0.835352 |
| 98 | +2017-07-08 13:27:42,542 Epoch[98] Batch [150] Speed: 4264.88 samples/sec Accuracy=0.839395 |
| 99 | +2017-07-08 13:27:54,544 Epoch[98] Batch [200] Speed: 4266.31 samples/sec Accuracy=0.836328 |
| 100 | +2017-07-08 13:28:06,550 Epoch[98] Batch [250] Speed: 4264.50 samples/sec Accuracy=0.841465 |
| 101 | +2017-07-08 13:28:18,622 Epoch[98] Batch [300] Speed: 4241.11 samples/sec Accuracy=0.831680 |
| 102 | +2017-07-08 13:28:19,349 Epoch[98] Train-Accuracy=0.833984 |
| 103 | +2017-07-08 13:28:19,349 Epoch[98] Time cost=73.018 |
| 104 | +2017-07-08 13:28:19,393 Saved checkpoint to "ocr-0099.params" |
| 105 | +2017-07-08 13:28:24,472 Epoch[98] Validation-Accuracy=0.818034 |
| 106 | +2017-07-08 13:28:37,961 Epoch[99] Batch [50] Speed: 4242.14 samples/sec Accuracy=0.835861 |
| 107 | +2017-07-08 13:28:50,031 Epoch[99] Batch [100] Speed: 4241.94 samples/sec Accuracy=0.846543 |
| 108 | +2017-07-08 13:29:02,108 Epoch[99] Batch [150] Speed: 4239.22 samples/sec Accuracy=0.850645 |
| 109 | +2017-07-08 13:29:14,160 Epoch[99] Batch [200] Speed: 4248.34 samples/sec Accuracy=0.844141 |
| 110 | +2017-07-08 13:29:26,225 Epoch[99] Batch [250] Speed: 4243.71 samples/sec Accuracy=0.842129 |
| 111 | +2017-07-08 13:29:38,277 Epoch[99] Batch [300] Speed: 4248.07 samples/sec Accuracy=0.851250 |
| 112 | +2017-07-08 13:29:38,975 Epoch[99] Train-Accuracy=0.854492 |
| 113 | +2017-07-08 13:29:38,976 Epoch[99] Time cost=73.315 |
| 114 | +2017-07-08 13:29:39,023 Saved checkpoint to "ocr-0100.params" |
| 115 | +2017-07-08 13:29:44,110 Epoch[99] Validation-Accuracy=0.851969 |
| 116 | +``` |
0 commit comments