CTC
最近在学习文字识别相关内容,这里记录一下CTC的实现过程,欢迎批评指正😊
CNN+RNN阶段提取特征就先不说了,本文只简要总结一下CTC原理和keras实现代码。
首先放上原文和解释CTC最好的几个文章:
CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇 https://xiaodu.io/ctc-explained/
CTC Algorithm Explained Part 2:Decoding the Network(CTC算法详解之解码篇 https://xiaodu.io/ctc-explained-part2/
一文读懂CRNN+CTC文字识别 https://zhuanlan.zhihu.com/p/43534801
CTC原理简要总结
一般来说我们需要输入和输出都是一一对应且标注好的,如果这样对于文本识别来说你不仅需要标注字符还需要标注位置,这是非常困难的,CTC的出现解决了这种对齐问题。CTC引入了空格字符‘-’,定义了一个多对一的β变换,举个例子:abbcd ,经过β变换可能输出abbcd的情况有-a-bb-b-cd, a-b-b-c-d-等等很多(但是是有限的,因为你的input_length是定值)。
定义x是输入,z是正确输出,我们希望p(z|x)尽可能的大,如果要计算p(z|x)我们可以选择计算所有能映射到abbcd的‘路径’。
这里补充解释一下图中CTC loss第一行,S是样本集合。实际训练时就是一个batch里的数据,我们希望这个batch所有数据预测对的概率最大,所以求最大似然(把每条数据预测对的概率累乘),这里求loss希望loss小所以取负对数似然。
直接暴力计算p(z|x)在识别较长序列时显然不现实,作者借鉴HMM的Forward-Backward算法思路采用动态规划的方法求解。
这里详情参考引用1
CTC训练bug解决
这里在提一下之前训练时遇到的一个bug:
原因是:网上说是label_legth>input_length,查询之后发现没有label_legth>input_length。实际是这样的,比如你的label是abb, input_length是3,那么必然报错,因为你至少输出ab-b四个字符才能得到abb。
解决方法:
- 修改CTC Loss里参数preprocess_collapse_repeated = True
- 过滤掉可能出现这种情况的数据
keras CTC实现代码
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)`
containing the truth labels.
y_pred: tensor `(samples, time_steps, num_categories)`
containing the prediction, or output of the softmax.
input_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_pred`.
label_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_true`.
# Returns
Tensor with shape (samples,1) containing the
CTC loss of each element.
"""
label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
sparse_labels = tf.cast(
ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
y_pred = tf_math_ops.log(tf.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
return tf.expand_dims