CTCLOSS backward报错

在进行中文训练时遇到CTCLOSS反向传播错误,通过测试warpctc代码发现问题。解决方案是修改$PATH_TO_warp-ctc/pytorch_binding/warpctc_pytorch/__init__.py文件中的backward()函数返回值。此外,训练程序中的DataParallel导致keyError,需在demo.py中添加相应语句。还遇到了维度错误,通过调整网络维度和注释代码解决。最终选择TensorFlow版crnn,简化操作。
摘要由CSDN通过智能技术生成

利用该项目进行中文训练时,程序在ctcloss反向传播时出现问题。为验证ctcloss的有效性,网上找了一段测试代码。

import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))
label_sizes = Variable(torch.IntTensor([2]))
probs_sizes = Variable(torch.IntTensor([2]))
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()

如上为此处测试warpctc的一段代码,运行后报错。

>>> cost.backward()
Traceback (most recent call last):
  File "<stdin>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值