CTCloss从理论到训练

CTCloss

首先来介绍比较复杂抽象的CTCloss。
先上大神的链接
CTC 的全称是Connectionist Temporal Classification,中文名称是“连接时序分类”,这个方法主要是解决神经网络label 和output 不对齐的问题(Alignment problem),其优点是不用强制对齐标签且标签可变长,仅需输入序列和监督标签序列即可进行训练
首先简单说一下CTCLoss的应用场景,适用于文字识别,验证码识别,手写数字识别,语音识别等领域。
为什么呢?这就是由于CTCLoss的原理来决定的了。
今天的讲解是基于CRNN这个网路结构来讲解CTCloss的。其实Lipnet也是一个CRNN的网络,后面详细讲解。这就是CRNN的网络结构,cnn+rnn+ctcloss,这是处理文字模型常用的一种方案之一。
在这里插入图片描述
在这里插入图片描述
然后经过cnn网络输出(1,25,512)大小的卷积特征网络,然后再经过rnn,这里的rnn一般有两种:
一种是深层双向的RNN网络
在这里插入图片描述
还有一种就是stack-双向RNN网络
在这里插入图片描述
至于这两种网络的具体网络细节,我会在之后的博文中详细写到(大家可以关注,点赞一波。哈哈)
在这里插入图片描述
在这里插入图片描述
我这里是这样给出的,具体根据实际情况来。

到这里基本的网络结构大家心中就已经有数了,但是问题也随之而来。文字识别,语音输入都有一个问题就是,叠字(叠声)还有空字(空声,也就是有的人说话快,有的人说话慢)的问题。这样的话我们该怎么解决呢?总不能逐帧标注吧,这样工作量也太大了!!

先给出流程图:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
t只不过是一个时间时刻,在这一个时间时刻有一个输入图像,所以这里每一个xt,yt都对应有一列输入,这一列的长度取决于你的label数量。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
接下来就是CTCLOSS的部分了。
在这里插入图片描述
在这里插入图片描述
这里的T是根据你的输入数据自行决定的,一般去你输入数据的最大序列,所以有很多条路径,我们想要计算出所有的路径再求argmax显然是不太现实的,所以需要一种快速的计算方法。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这就是我们的最终结论,写的很详细,多看两遍就懂了。
下面讲一下怎么训练CTC:
在这里插入图片描述
具体求导过程如下所示(这里为了计算简便,对loss函数取对数):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
讲到这里CTC的理论知识基本就讲完了,在tensorflow中和pytorch(1.1以后版本)中都有内置的CTCloss函数。
下面是tensorflow的接口,有三个参数:
tf.nn.ctc_loss( labels, inputs, sequence_length, preprocess_collapse_repeated=False, ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=False, time_major=True )
pytorch的CTCloss有四个接口:

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')#blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定;

#reduction:处理output losses的方式,string类型,可选'none' 、 'mean' 及 'sum','none'表示对#output losses不做任何处理,'mean' 则对output losses取平均值处理,'sum'则是对output losses求和处#理,默认为'mean' 。

ctc_loss(log_probs, targets, input_lengths, target_lengths)

log_probs:shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;

targets:shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;

input_lengths:shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;

target_lengths:shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;
具体的pytorch中CTCloss的操作可以查看这个链接

  • 18
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值