OCR文字识别之CTC原理和实现

本文简要介绍了CTC(Connectionist Temporal Classification)算法在文字识别中的应用,包括其解决对齐问题的作用、CTC损失函数的解释、训练中的常见问题及解决方案,以及Keras中CTC损失函数的实现细节。通过学习,读者可以理解CTC如何处理序列到序列的非对齐映射,并掌握其在实际操作中的应用。
摘要由CSDN通过智能技术生成

CTC

最近在学习文字识别相关内容,这里记录一下CTC的实现过程,欢迎批评指正😊

CNN+RNN阶段提取特征就先不说了,本文只简要总结一下CTC原理和keras实现代码。
首先放上原文和解释CTC最好的几个文章:

CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇 https://xiaodu.io/ctc-explained/
CTC Algorithm Explained Part 2:Decoding the Network(CTC算法详解之解码篇 https://xiaodu.io/ctc-explained-part2/
一文读懂CRNN+CTC文字识别 https://zhuanlan.zhihu.com/p/43534801

CTC原理简要总结
一般来说我们需要输入和输出都是一一对应且标注好的,如果这样对于文本识别来说你不仅需要标注字符还需要标注位置,这是非常困难的,CTC的出现解决了这种对齐问题。CTC引入了空格字符‘-’,定义了一个多对一的β变换,举个例子:abbcd ,经过β变换可能输出abbcd的情况有-a-bb-b-cd, a-b-b-c-d-等等很多(但是是有限的,因为你的input_length是定值)。

定义x是输入,z是正确输出,我们希望p(z|x)尽可能的大,如果要计算p(z|x)我们可以选择计算所有能映射到abbcd的‘路径’。

这里补充解释一下图中CTC loss第一行,S是样本集合。实际训练时就是一个batch里的数据,我们希望这个batch所有数据预测对的概率最大,所以求最大似然(把每条数据预测对的概率累乘),这里求loss希望loss小所以取负对数似然。
直接暴力计算p(z|x)在识别较长序列时显然不现实,作者借鉴HMM的Forward-Backward算法思路采用动态规划的方法求解。

这里详情参考引用1

CTC训练bug解决
这里在提一下之前训练时遇到的一个bug:

原因是:网上说是label_legth>input_length,查询之后发现没有label_legth>input_length。实际是这样的,比如你的label是abb, input_length是3,那么必然报错,因为你至少输出ab-b四个字符才能得到abb。
解决方法:

  1. 修改CTC Loss里参数preprocess_collapse_repeated = True
  2. 过滤掉可能出现这种情况的数据

keras CTC实现代码

def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
        y_true: tensor `(samples, max_string_length)`
            containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)`
            containing the prediction, or output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.

    # Returns
        Tensor with shape (samples,1) containing the
            CTC loss of each element.
    """
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
    y_pred = tf_math_ops.log(tf.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
    return tf.expand_dims
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值