CTCLoss模块安装:
CRNN.pytorch源代码:
meijieru/crnn.pytorch: Convolutional recurrent network in pytorch (github.com)
方式一:
pytorch
(version>=1.1)含有内置CTCLoss模块,所以只需要更改CRNN文档中相应的几行代码应用即可。
train.py中修改如下几行:
<1> from warpctc_pytorch import CTCLoss ——>from torch.nn import CTCLoss
<2>criterion = CTCLoss() ——>criterion = CTCLoss(blank=0, reduction='mean')
<3> cost = criterion(preds, text, preds_size, length) / batch_size ——> cost = criterion(preds.log_softmax(2), text, preds_size, length) / batch_size
方式二:
额外安装warp_ctc_pytorch(Windows环境,需要安装cmake来编译文件): https://github.com/SeanNaren/warp-ctc
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding
python setup.py install