LPRNetN 车牌识别会用到
CTCLoss
batch_size 16
这个入门也不错:
如何使用pytorch内置torch.nn.CTCLoss的方法&&车牌识别应用_CSDNwei的专栏-CSDN博客_pytorch 车牌识别
import torch
import torch.nn as nn
ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss)
# loss.backward()
以下转自:
如何优雅的使用pytorch内置torch.nn.CTCLoss的方法 - 知乎
二、CTCLoss接口使用说明
第一步,获取CTCLoss()对象
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
类初始化参数说明:
blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定;
reduction:处理output losses的方式,string类型,可选'none' 、 'mean' 及 'sum','none'表示对output losses不做任何处理,'mean' 则对output losses取平均值处理,'sum'则是对output losses求和处理,默认为'mean' 。
第二步,在迭代中调用CTCLoss()对象计算损失值
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
CTCLoss()对象调用形参说明:
log_probs:shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;
targets:shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;
input_lengths:shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;
target_lengths:shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;
举个具体例子说明如何使用CTCLoss(),如下为CTCLoss在车牌识别里面的应用:
比如我们需要预测的字符集如下,其中'-'表示空白标签;
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
'桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
'新',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O', '-'
]
因为空白标签所在的位置为len(CHARS)-1,而我们需要处理CTCLoss output losses的方式为‘mean’,则需要按照如下方式初始化CTCLoss类:
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
我们设定输出序列长度T为18,训练批大小N为4且训练数据集仅有4张车牌(为了方便说明)如下,总的字符集长度C如上面CHARS所示为68:
那么我们在训练一次迭代中打印各个输入形参得出如下结果:
1)log_probs由于数值比较多且为神经网络前向输出结果,我们仅打印其shape出来,如下:
torch.Size([18, 4, 68])
2)打印targets如下,表示这四张车牌的训练标签,根据target_lengths划分标签后可分别表示这四张车牌:
tensor([18, 45, 33, 37, 40, 49, 63, 4, 54, 51, 34, 53, 37, 38, 22, 56, 37, 38,33, 39, 34, 46, 2, 41, 44, 37, 39, 35, 33, 40])
3)打印target_lengths如下,每个元素分别指定了按序取targets多少个元素来表示一个车牌即标签: