pytorch内置torch.nn.CTCLoss

一、开篇简述

CTC 的全称是Connectionist Temporal Classification,中文名称是“连接时序分类”,这个方法主要是解决神经网络label 和output 不对齐的问题(Alignment problem),其优点是不用强制对齐标签且标签可变长,仅需输入序列和监督标签序列即可进行训练,目前,该方法主要应用于场景文本识别(scene text recognition)、语音识别(speech recognition)及手写字识别(handwriting recognition)等工程场景。以往我们在百度上搜索pytorch + ctc loss得到的结果基本上warp-ctc的使用方法,warp-ctc是百度开源的一个可以应用在CPU和GPU上高效并行的CTC代码库,但是为了在pytorch上使用warp-ctc我们不仅需要编译其源代码还需要进行安装配置,使用起来着实麻烦。而在Pytorch 1.0.x版本内早就有内置ctc loss接口了,我们完全可以直接使用,只是很少有资料介绍如何使用该API。因此,本篇文章结合我个人工程实践中的经验介绍我在pytorch中使用其内置torch.nn.CTCLoss的方法,但不会对ctc loss原理进行展开,期望能给大家在工程实践中使用torch.nn.CTCLoss带来帮助!

二、CTCLoss接口使用说明

第一步,获取CTCLoss()对象

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')

类初始化参数说明:

blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定;

reduction:处理output losses的方式,string类型,可选’none’ 、 ‘mean’ 及 ‘sum’,’none’表示对output losses不做任何处理,’mean’ 则对output losses取平均值处理,’sum’则是对output losses求和处理,默认为’mean’ 。

第二步,在迭代中调用CTCLoss()对象计算损失值

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

CTCLoss()对象调用形参说明:

log_probs:shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;

targets:shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;

input_lengths:shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;

target_lengths:shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;

举个具体例子说明如何使用CTCLoss(),如下为CTCLoss在车牌识别里面的应用:

比如我们需要预测的字符集如下,其中’-‘表示空白标签;

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]

因为空白标签所在的位置为len(CHARS)-1,而我们需要处理CTCLoss output losses的方式为‘mean’,则需要按照如下方式初始化CTCLoss类:

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction=’mean’)

我们设定输出序列长度T为18,训练批大小N为4且训练数据集仅有4张车牌(为了方便说明)如下,总的字符集长度C如上面CHARS所示为68:

BVjUvm.pnguploading.4e448015.gif转存失败重新上传取消《如何优雅的使用pytorch内置torch.nn.CTCLoss的方法》

那么我们在训练一次迭代中打印各个输入形参得出如下结果:

1)log_probs由于数值比较多且为神经网络前向输出结果,我们仅打印其shape出来,如下:

torch.Size([18, 4, 68])

2)打印targets如下,表示这四张车牌的训练标签,根据target_lengths划分标签后可分别表示这四张车牌:

tensor([18, 45, 33, 37, 40, 49, 63, 4, 54, 51, 34, 53, 37, 38, 22, 56, 37, 38,33, 39, 34, 46, 2, 41, 44, 37, 39, 35, 33, 40])

3)打印target_lengths如下,每个元素分别指定了按序取targets多少个元素来表示一个车牌即标签:

(7, 7, 8, 8)

我们划分targets后得到如下标签:

18, 45, 33, 37, 40, 49, 63  -->> 车牌 “湘E269JY”
4, 54, 51, 34, 53, 37, 38   -->> 车牌 “冀PL3N67”
22, 56, 37, 38,33, 39, 34, 46  -->> 车牌 “川R67283F”
2, 41, 44, 37, 39, 35, 33, 40  -->> 车牌 “津AD68429”

target_lengths元素数量的不同则表示了标签可变长。

4)打印input_lengths如下,由于输出序列长度T已经设定为18,因此其元素均是固定相同的:

(18, 18, 18, 18)

其中,只要模型配置固定了后,log_probs不需要我们组装再传送到CTCLoss,但是其余三个输入形参均需要我们根据实际数据集及C、T、N的情况进行设定!

三、需要注意的地方

3.1 官方所给的例程如下,但在实际应用中需要将log_probs的detach()去掉,否则无法反向传播进行训练;

>>> ctc_loss = nn.CTCLoss()
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
>>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()

3.2 blank空白标签一定要依据空白符在预测总字符集中的位置来设定,否则就会出错;

3.3 targets建议将其shape设为(sum(target_lengths)),然后再由target_lengths进行输入序列长度指定就好了,这是因为如果设定为(N, S),则因为S的标签长度如果是可变的,那么我们组装出来的二维张量的第一维度的长度仅为min(S)将损失一部分标签值(多维数组每行的长度必须一致),这就导致模型无法预测较长长度的标签;

3.4 输出序列长度T尽量在模型设计时就要考虑到模型需要预测的最长序列,如需要预测的最长序列其长度为I,则理论上T应大于等于2I+1,这是因为CTCLoss假设在最坏情况下每个真实标签前后都至少有一个空白标签进行隔开以区分重复项;

3.5 输出的log_probs除了进行log_softmax()处理再送入CTCLoss外,还必须要调整其维度顺序,确保其shape为(T, N, C)!

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个基于Python的科学计算库,主要针对深度学习任务。在PyTorch中,torch.nn是一个用于构建神经网络模型的模块。 torch.nn模块提供了一系列神经网络层和函数,方便用户构建自定义的神经网络。用户可以通过继承torch.nn.Module类来定义自己的神经网络模型。torch.nn模块中常用的类包括各种层(例如全连接层、卷积层、池化层和循环层等)、非线性激活函数和损失函数等。 在使用torch.nn模块构建神经网络时,用户需要实现模型的前向传播函数forward()。该函数定义了输入数据在神经网络中的流动方式,即通过层和函数的组合计算输出。在forward()函数中,用户可以使用已定义的层和函数进行计算,也可以实现自定义的操作。 torch.nn模块中的另一个重要概念是参数(parameter)。参数是模型中需要学习的变量,例如网络层的权重和偏置项。用户可以通过在模型中定义torch.nn.Parameter对象来创建参数,并在forward()函数中进行使用。 除了torch.nn模块外,PyTorch还提供了其他的工具和模块来辅助神经网络的训练和优化过程。例如torch.optim模块包含了各种优化算法,如随机梯度下降(SGD)、Adam等,用于更新模型中的参数。torch.utils.data模块提供了数据处理和加载的工具,方便用户使用自己的数据训练模型。 总之,torch.nn模块是PyTorch中用于构建神经网络模型的重要组成部分。通过使用torch.nn的各种类和函数,用户可以方便地创建自己想要的神经网络结构,并利用PyTorch强大的计算能力和优化算法来训练和优化模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值