PyTorch bindings for Warp-ctc: fixing the error "ATen/cuda/CUDAGuard.h: No such file or directory"

PyTorch (before version 1.0) has no built-in CTC loss, so the third-party library Warp-ctc has to be installed.

The warp-ctc source is at: https://github.com/SeanNaren/Warp-ctc

Build it following the steps from the project page:

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make

Then install the Python bindings:

cd pytorch_binding
python setup.py install

This produces the following error:

root@localhost:/home/ocrtrain/train/ocr/warp-ctc/pytorch_binding# python setup.py install
running install
running bdist_egg
running egg_info
creating warpctc_pytorch.egg-info
writing warpctc_pytorch.egg-info/PKG-INFO
writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt
writing top-level names to warpctc_pytorch.egg-info/top_level.txt
writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
reading manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-3.6
creating build/lib.linux-x86_64-3.6/warpctc_pytorch
copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.6/warpctc_pytorch
running build_ext
building 'warpctc_pytorch._warp_ctc' extension
creating build/temp.linux-x86_64-3.6
creating build/temp.linux-x86_64-3.6/src
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/michael/ocrtrain/train/ocr/warp-ctc/include -I/usr/local/lib/python3.6/dist-packages/torch/include -I/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.6/dist-packages/torch/include/TH -I/usr/local/lib/python3.6/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.6m -c src/binding.cpp -o build/temp.linux-x86_64-3.6/src/binding.o -std=c++11 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0
src/binding.cpp:10:11: fatal error: ATen/cuda/CUDAGuard.h: No such file or directory
#include "ATen/cuda/CUDAGuard.h"
^~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

I am using CUDA 10, while this version of the bindings apparently only supports up to CUDA 9, hence the missing CUDA header. I simply switched to the CPU-only build.
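
To double-check the mismatch, a quick diagnostic (my own addition, not part of the original fix; torch.version.cuda reports the CUDA version the installed PyTorch was built against, or None for a CPU-only build):

import torch

print(torch.__version__)          # installed PyTorch version
print(torch.version.cuda)         # CUDA version PyTorch was compiled against; None if CPU-only
print(torch.cuda.is_available())  # whether PyTorch can actually see a usable GPU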

Go into the warp-ctc/pytorch_binding/src/ directory and find the file binding.cpp.

Edit it, removing the GPU parts:

#include <iostream>
#include <vector>
#include <numeric>
#include <cstring>    // for memset
#include <algorithm>  // for std::max
#include <torch/extension.h>
#include "ctc.h"
int cpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label)
{
    float* probs_ptr       = (float*)probs.data_ptr();
    float* grads_ptr       = grads.storage() ? (float*)grads.data_ptr() : NULL;
    int*   sizes_ptr       = (int*)sizes.data_ptr();
    int*   labels_ptr      = (int*)labels.data_ptr();
    int*   label_sizes_ptr = (int*)label_sizes.data_ptr();
    float* costs_ptr       = (float*)costs.data_ptr();
    const int probs_size = probs.size(2);
    ctcOptions options;
    memset(&options, 0, sizeof(options));
    options.loc = CTC_CPU;
    options.num_threads = 0; // will use default number of threads
    options.blank_label = blank_label;
#if defined(CTC_DISABLE_OMP) || defined(APPLE)
    // have to use at least one
    options.num_threads = std::max(options.num_threads, (unsigned int) 1);
#endif
    // ask warp-ctc how much scratch memory it needs, then allocate it
    size_t cpu_size_bytes;
    get_workspace_size(label_sizes_ptr, sizes_ptr,
                       probs_size, minibatch_size,
                       options, &cpu_size_bytes);
    float* cpu_workspace = new float[cpu_size_bytes / sizeof(float)];
    compute_ctc_loss(probs_ptr, grads_ptr,
                     labels_ptr, label_sizes_ptr,
                     sizes_ptr, probs_size,
                     minibatch_size, costs_ptr,
                     cpu_workspace, options);
    delete[] cpu_workspace;
    return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("cpu_ctc", &cpu_ctc, "CTC Loss function with cpu");
}

Run python setup.py install again and it compiles and installs successfully.
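
For a quick smoke test of the raw extension (a sketch of my own, assuming the module installs as warpctc_pytorch._warp_ctc as the build log above indicates; the arguments follow the cpu_ctc signature from binding.cpp):

import torch
from warpctc_pytorch import _warp_ctc  # the compiled extension module

# 2 time steps, batch of 1, alphabet of 5 (index 0 is the blank label)
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
grads = torch.zeros(probs.size())   # filled with d(loss)/d(probs) by compute_ctc_loss
labels = torch.IntTensor([1, 2])    # flattened label sequences for the whole batch
label_sizes = torch.IntTensor([2])  # label length of each batch element
sizes = torch.IntTensor([2])        # input (time) length of each batch element
costs = torch.zeros(1)              # receives one CTC loss per batch element

_warp_ctc.cpu_ctc(probs, grads, labels, label_sizes, sizes, 1, costs, 0)
print(costs)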

A second approach, for reference:

https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.zh_cn.md

https://blog.csdn.net/AMDS123/article/details/73433926

From the warp-ctc root directory, run "luarocks install http://raw.githubusercontent.com/baidu-research/warp-ctc/master/torch_binding/rocks/warp-ctc-scm-1.rockspec" (this installs the Torch/Lua binding).

To test it, copy the warp-ctc/pytorch_binding/build/warpctc_pytorch directory into the same directory as the test script, then create test.py (vi test.py) and run python test.py:

import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))
label_sizes = Variable(torch.IntTensor([2]))
probs_sizes = Variable(torch.IntTensor([2]))
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print('PyTorch bindings for Warp-ctc')

Alternatively: cd warp-ctc/pytorch_binding/tests && python test_cpu.py

For reference, the rockspec used by the luarocks command above looks like this:

package = "warp-ctc"
version = "scm-1"

source = {
   url = "git://github.com/baidu-research/warp-ctc.git",
}

description = {
   summary = "Baidu CTC Implementation",
   detailed = [[
   ]],
   homepage = "https://github.com/baidu-research/warp-ctc",
   license = "Apache"
}

dependencies = {
   "torch >= 7.0",
}

build = {
   type = "command",
   build_command = [[
cmake -E make_directory build && cd build && cmake .. -DLUALIB=$(LUALIB) -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) -j$(getconf _NPROCESSORS_ONLN) && make install
]],
   platforms = {},
   install_command = "cd build"
}

PyTorch 1.0 already supports CTC loss natively:

import torch

loss = torch.nn.CTCLoss()  # must be instantiated: torch.nn.CTCLoss() with parentheses

Some people say the built-in CTCLoss has bugs; I have not run into any myself. Note that its calling convention differs from warp-ctc's, as sketched below.
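
A minimal usage sketch of the built-in API (my own example, not from the original post): unlike warp-ctc, which takes raw network activations, torch.nn.CTCLoss expects log-probabilities (e.g. from log_softmax) of shape (T, N, C), and the blank index defaults to 0:

import torch

T, N, C = 50, 16, 20              # input length, batch size, alphabet size (incl. blank)
ctc = torch.nn.CTCLoss(blank=0)   # blank index defaults to 0

log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()  # (T, N, C) log-probs
targets = torch.randint(1, C, (N, 30), dtype=torch.long)          # labels, avoiding blank=0
input_lengths = torch.full((N,), T, dtype=torch.long)             # time steps per sample
target_lengths = torch.randint(10, 30, (N,), dtype=torch.long)    # label length per sample

cost = ctc(log_probs, targets, input_lengths, target_lengths)
cost.backward()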
