warp_ctc源码地址为:https://github.com/SeanNaren/Warp-ctc
1.下载压缩包并解压,或直接git clone源码。
git clone https://github.com/SeanNaren/warp-ctc.git2.运行如下指令进行编译:
-
mv warp-ctc-pytorch_bindings/ warp-ctc
-
cd warp-ctc
-
mkdir build;
cd build
-
cmake ..
-
make
3.开始安装:
-
cd
pytorch_binding
-
python
setup
.py
install
4. 安装结束后,在warp-ctc/pytorch_binding/build目录下新建一个test.py文件进行测试。目录结构如下:
文件内容为:
-
import torch
-
from warpctc_pytorch
import CTCLoss
-
ctc_loss = CTCLoss()
-
# expected shape of seqLength x batchSize x alphabet_size
-
probs = torch.FloatTensor([[[
0.1,
0.6,
0.1,
0.1,
0.1], [
0.1,
0.1,
0.6,
0.1,
0.1]]]).transpose(
0,
1).contiguous()
-
labels = torch.IntTensor([
1,
2])
-
label_sizes = torch.IntTensor([
2])
-
probs_sizes = torch.IntTensor([
2])
-
probs.requires_grad_(
True)
# tells autograd to compute gradients for probs
-
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
-
cost.backward()
执行python test.py,若没有报错,则证明编译及运行成功。
6.将其应用于其他项目中。若某项目存在一py格式文件,存在如下调用方式:
则将warp-ctc/pytorch_binding/build/warpctc_pytorch 目录拷贝至与该py文件同级的目录下。
cp -r ~/warp-ctc/pytorch_binding/build/warpctc_pytorch .
再执行该py文件即可成功。