pytorch的torch.nn.CTCLoss方法

CTC(Connectionist Temporal Classification)是一种解决神经网络输出与标签对齐问题的方法,尤其适用于序列数据如语音识别、场景文本识别。CTCLoss用于计算损失,其初始化需指定空白标签位置和损失处理方式。在训练过程中,log_probs、targets、input_lengths和target_lengths是关键参数。CTC Loss允许标签序列长度可变,提高了模型的灵活性。注意,空白标签的设置和损失计算方式需正确匹配。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、CTC说明

CTC的全称为Connectionist Temporal Classification,中文名称为:连接时序分类。这个方法主要是解决神经网络label和output不对齐的问题,其优点是不用强制对齐标签且标签可变长,仅需输入序列和监督标签序列即可进行训练。目前,该方法主要应用于场景文本识别、语音识别及手写字识别等工程场景。

怎么可以实现不对齐标签?定义一个多对一的映射B,目的是为了合并有相同输出的路径。举个例子,我们定一个规则:仅仅合并两个'-'间多余的字符并且移除所有的'-',那么:

B(a-ab-) = B(-aa--abb) = aab

也就是同一个输出有不同的路径。

二、CTCLoss接口使用说明

1、获取CTCLoss对象

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')

blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定

reduction:处理output losses的方式

2、在迭代中调用CTCLoss()对象计算损失值

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

log_probs:shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即预测模型输出序列长度(比如LPRNet算法预测模型输出长度为18),N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度。log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中。

targets:shape为(N, S)或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度;第二种类型,则为所有标签长度之和。但是需要注意的是targets不能包含有空白标签。

input_lengths: shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同。这个就是模型预测出来的长度,一般是固定的,比如LPRNet的长度为18。

target_lengths:shape为(N)的张量或元组,其每一个元素指示训练输入序列的标签长度,但标签长度是可以变化的。

3、举例说明

比如训练车牌的预测字符集如下,其中'-'表示空白标签:

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]

因为空白标签所在的位置为len(CHARS) - 1,而我们需要处理CTCLoss output losses的方式为'mean',则需要按照如下方式初始化CTCLoss类:

ctc_loss = nn.CTCLosss(blank=len(CHARS)-1, reduction='mean')

我们设定输出序列长度T为18,训练批大小N为4且训练数据集仅有4张车牌如下,总的字符集长度C如上图CHARS所示为68:

那么在训练一次迭代中打印各个输入形参得出如下结果:

1)log_probs由于数值比较多且为神经网络前向输出结果,仅打印其shape出来,如下:

torch.Size([18, 4, 68])

2)打印targets如下,表示这四张车牌的训练标签,根据target_lengths划分标签后可分别表示这四张车牌:

tensor([18, 45, 33, 37, 40, 49, 63, 4, 54, 51, 34, 53, 37, 38, 22, 56, 37, 38,33, 39, 34, 46, 2, 41, 44, 37, 39, 35, 33, 40])

3)打印target_lengths如下,每个元素分别指定了按序取targets多少个元素来表示一个车牌即标签:

(7,7,8,8)

划分targets后得到如下标签:

18, 45, 33, 37, 40, 49, 63  -->> 车牌 “湘E269JY”
4, 54, 51, 34, 53, 37, 38   -->> 车牌 “冀PL3N67”
22, 56, 37, 38,33, 39, 34, 46  -->> 车牌 “川R67283F”
2, 41, 44, 37, 39, 35, 33, 40  -->> 车牌 “津AD68429”

target_lengtsh元素数量的不同则表示了标签可变长。

4)打印input_lengths如下,由于输出序列长度T已经设定为18,因此其元素均是固定相同的:

(18, 18, 18, 18)

4、注意事项

(1)ctcloss代码示例:

import torch
import torch.nn as nn

ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2).requires_grad_()
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()

(2)blank空白标签一定要依据空白符在预测总字符集中的位置来设定,否则就会出错。

(3)输出序列长度T尽量在模型设计时就要考虑模型需要预测的最长序列,如需要预测的最长序列其长度为l,则理论上T应大于等于2l+1,这是因为CTCLoss假设在最坏情况下每个真实标签前后都至少有一个空白标签进行隔开以区分重复项。

PyTorch中的CTCLoss是指Connectionist Temporal Classification Loss,它是一种用于解决神经网络标签和输出不对齐问题的方法CTCLoss的优点是不需要强制对齐标签且标签可以是可变长度的。它主要应用于场景文本识别、语音识别和手写字识别等工程场景。在PyTorch 1.0.x版本内,已经内置了CTCLoss接口,可以直接使用。下面是一个使用CTCLoss的代码示例: ```python import torch import torch.nn as nn ctc_loss = nn.CTCLoss() log_probs = torch.randn(50, 16, 20).log_softmax(2).requires_grad_() targets = torch.randint(1, 20, (16, 30), dtype=torch.long) input_lengths = torch.full((16,), 50, dtype=torch.long) target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) loss.backward() ``` 在这个示例中,我们首先创建了一个CTCLoss实例,然后生成了一些随机的log probabilities作为网络的输出。接着,我们生成了一些随机的目标标签和输入长度以及目标长度。最后,我们使用CTCLoss计算了损失,并进行了反向传播。\[2\] 在创建CTCLoss实例时,可以通过设置参数来自定义一些属性。例如,可以使用`blank`参数来指定空白符的序号,`reduction`参数来指定损失的计算方式。\[3\] 希望这个回答对你有帮助! #### 引用[.reference_title] - *1* [如何使用pytorch内置torch.nn.CTCLoss方法&&车牌识别应用](https://blog.csdn.net/CSDNwei/article/details/120223026)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [pytorchtorch.nn.CTCLoss方法](https://blog.csdn.net/benben044/article/details/125130411)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Pytorch中的CTC loss](https://blog.csdn.net/fidbdiej/article/details/124587812)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值