时间片轮转源码_CTC loss 笔记+源码分享

这篇博客详细介绍了CTC(Connectionist Temporal Classification)损失函数,它用于解决序列到序列任务中的对齐问题,尤其适用于OCR和语音识别。CTC通过引入空白字符解决连续字符重复和不同长度序列对齐的挑战,并使用动态规划计算概率路径。文章讨论了对齐策略、路径搜索、简化计算以及logsumexp的优化,同时分享了CTC损失的PyTorch源码分析。
摘要由CSDN通过智能技术生成

简介

在ocr任务与机器翻译中,输入与输出GT文本很难在单词上对齐,在预处理的时候对齐是非常困难的,但是如果不对齐而直接训练模型的话,由于字符距离的不同,导致模型很难收敛.

CTC(Connectionist Temporal Classification)避免了输入与输出手动对齐,适合OCR或语音这样的序列应用;

建模

给定输入序列

,以及对应的标签数据
.
不一定相等 我们的工作是找到一个X到Y的映射.这种对时序数据进行分类算法叫做Temporal Classification。

对比传统分类方法,时序分类有一下困难:

  • X和Y的长度都是变化且不相等的.
  • 对于一个端到端的模型,我们并不想手动设计X和Y之间的对齐.

CTC提供了解决方案,对于一个给定的输入序列

,CTC给出所有可能的 Y 的输出分布。根据这个分布,我们可以输出最可能的结果或者给出某个输出的概率。

loss:给定输入序列X,我们希望最大化Y的后验概率

,
应该是可导的,这样我们就能执行梯度下降算法进行优化.

infer:

1.1 对齐

在ocr任务中,输入X是一张含有"CAT"的图片,输出Y是文本[C,A,T] 最原始的对齐方式将X分割成若干个时间片,每一个时间片得到一个字符的输出,然后合并连续重复出现的字符.
然而这样做有两个缺点:

  • 几乎不可能将 X 的每个时间片都和输出Y对应上,例如OCR中字符的间隔,语音识别中的停顿;
  • 不能处理有连续重复字符出现的情况,例如单词“APPLE”,按照上面的算法,输出的是“APLE”而非“APPLE”。

为了解决上面的问题,CTC引入了空白字符, CTC的对齐涉及去除重复字母和去除 空白字符 两部分. 其规则:

  • 连续相同的字符做去重
  • 去重空白字符

比如,对于长度为10的输入序列,以下RNN输出序列都可以映射为: apple

  • _aappp_ple
  • ap_p_|_ e
  • _ _ app_ple_

最后要计算P(Y|X),可以累加其对应的全部输出(全部输出为apple)的路径概率之和.

因此,在训练阶段,我们要对GT进行标签扩充.其做法是: 头尾加空白符,并在GT中的每一个字符间插入空白符; 用l表示最终标签,l’表示扩展后的形式, 则由2|l| + 1 = 2|l’|,比如:l=apple => l’=_a_p_p_l_e_

如图所示:

b0f9394b9a57da7e5a3affcc3b4c3b4f.png

1.2 路径搜索与动态规划

37671cfbce8d174fcc33cd1fb1ac5d75.png

b18114cc41a6b10ae2f1d56fc79fba76.png

05ad8dccd5f038e571f3dae5e27e01f4.png

上图的路径搜索中: 定义:

  • 为第t个时刻,gt字符串
    的第s个字符的路径前向概率.
  • 为预测矩阵中第t时刻是第s个字符的概率.
  • 为输入序列x,输出为l的概率,我们要最大化其概率

(1) 如果

为空白符.则
只能由前一个空白符
或者其GT中该字符为上一时刻
得到(因为我们是隔一个字符插入空白符,当前字符是空白的话,如果前一个也不是s的字符,就会错过GT中s字符,导致最后的path没法解析到GT,所以要么最多连续两个空白格,要么是前一个已经出现s字符,当前可以为空白)

所以这种情况下其概率为:

(2)如果

不为空白符,那么该点的前向概率之和可以通过以下路径得到
  • :s为当前gt的第s个字符,然后前一个为空白字符的概率
  • :当前字符s连续出现的概率
  • :前一个字符GT的上一个字符,当前步是GT的S个字符,代表s-1,s之间没有空白符,没有连续的字符的概率(因为每个字符都隔了一个空白符)

所以这种情况下前向概率为:

初始化值:

(代表了T时刻可以从空白符出发,也可以从gt的第0个字符开始)

最后我们需要计算

两个前向概率之和便得到前向概率之和.如图的右下角两个位置概率之和.

利用前向概率计算ctc 的loss 即

等于最小化对数域.

所以loss的值为:

简化计算

我们看到在计算过程中我们发现了大量的连乘。由于每一个数字都是浮点数,那么这样连乘下去,最终数字有可能非常小而导致underflow。所以我们要将这个计算过程转到对数域上。这样我们就将其中的乘法转变成了加法。

由于最后计算loss为-ln(P)

所以在计算前向概率的时候可以直接计算log(p)的值..

logsumexp的优化

如果我们有N个概率,

,我们想求其对数域之和:

如果

很大或很小,朴素的直接计算会上溢出或下溢出,从而导致严重问题。举个例子,对于[0,1,0]直接计算是可行的,我们可以得到1.55。但对于[1000,1001,1000]却并不可行,我们会得到inf;对于[-1000,-999,-1000],还是不行,我们会得到-inf.

解决方法:

一般情况下,a取N个值中的最大值; 这可以保证指数最大不会超过0,于是你就不会上溢出。即便剩余的部分下溢出了,你也能得到一个合理的值。

证明:

CTC 概率图前向概率:

代表了第t时刻第s个gt字符(经过补充空白符)的概率

在torch的ctc 前馈过程中,计算的log前向概率值的矩阵,(用以进行loss back),我们看到其核心:

  • 每一个baych ,通过两层循环(T,S)动态规划计算前向概率Log值. 在计算的同时将同样需要计算的
    作为la1 ,la2,然后判断当前s的字符来决定第三个加项是否为
  • 同时,因为转换到了对数域,也避免了数变小与溢出的问题,其项变成了
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];

源码分析: 使用了pytorch中ctcloss 的源码

//pytorch/aten/src/ATen/native/LossCTC.cpp

//获取填充blank后的target指定位置的值,用来判断是
static inline int64_t get_target_prime(target_t *target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK)
            {
                if (idx % 2 == 0)
                {
                    return BLANK;
                }
                else
                {
                    return target[offset + stride * (idx / 2)];
                }
            }

//ctc_loss_cpu_template部分核心代码
//前向概率[B,T,N]
                Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, log_probs.options());

                Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

                //[B,T,N]
                auto lpp = log_probs.permute({1, 0, 2});
                auto log_probs_a_global = lpp.accessor<scalar_t, 3>();
                auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
                auto targets_data = targets.data_ptr<target_t>();
                auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();

                // alpha calculation for the first row, the three equations for alpha_1 above eq (6)
                // first the default
                log_alpha.narrow(1, 0, 1).fill_(neginf);
                at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
                    for (int64_t b = start; b < end; b++)
                    {
                        //每个batch
                        int64_t input_length = input_lengths[b];
                        int64_t target_length = target_lengths[b];
                        auto log_probs_a = log_probs_a_global[b];
                        auto log_alpha_a = log_alpha_a_global[b];
                        int64_t tg_batch_offset = tg_batch_offsets[b];
                        
                        // the first two items of alpha_t above eq (6)
                        //初始化前向概率[t0][s0]
                        log_alpha_a[0][0] = log_probs_a[0][BLANK];
                        if (target_length > 0)
                            //[t0][s1]等于序列中第一个字符的概率
                            log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];

                        // now the loop over the inputs
                        for (int64_t t = 1; t < input_length; t++)
                        {
                            for (int64_t s = 0; s < 2 * target_length + 1; s++)
                            {
                                //对于每一个s,计算其概率
                                //获取第s个字符是什么
                                auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
                                // this loop over s could be parallel/vectorized, too, but the required items are one index apart
                                // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
                                // for the cuda implementation, that gave a speed boost.
                                // This is eq (6) and (7), la1,2,3 are the three summands. We keep track of the maximum for the logsumexp calculation.

                                scalar_t la1 = log_alpha_a[t - 1][s];
                                scalar_t lamax = la1;
                                scalar_t la2, la3;
                                if (s > 0)
                                {
                                    
                                    la2 = log_alpha_a[t - 1][s - 1];
                                    if (la2 > lamax)
                                        lamax = la2;
                                }
                                else
                                {
                                    la2 = neginf;
                                }
                                if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s - 2, BLANK) !=
                                                current_target_prime))
                                {
                                    //第s个字符不是空且不等于s-2(即不连续的时候),即动态转移方程的第二个式子
                                    la3 = log_alpha_a[t - 1][s - 2];
                                    if (la3 > lamax)
                                        lamax = la3;
                                }
                                else
                                {   
                                    //s为空或者连续,第三项不用加
                                    la3 = neginf;
                                }
                                //添加概率最大项按前一个[t-1][s]
                                if (lamax == neginf) // cannot do neginf-neginf
                                    lamax = 0;

                                //计算此时的
                                // this is the assignment of eq (6)
                                log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
                            }
                        }
                        // the likelihood is the the sum of the last two alphas, eq (8), the loss is the negative log likelihood
                        if (target_length == 0)
                        {
                            // if the target is empty then there is no preceding BLANK state and hence there is no path to merge
                            neg_log_likelihood_a[b] = -log_alpha_a[input_length - 1][0];
                        }
                        else
                        {
                            scalar_t l1 = log_alpha_a[input_length - 1][target_length * 2];
                            scalar_t l2 = log_alpha_a[input_length - 1][target_length * 2 - 1];
                            //取两条路的概率之和
                            scalar_t m = std::max(l1, l2);
                            m = ((m == neginf) ? 0 : m);
                            scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
                            neg_log_likelihood_a[b] = -log_likelihood;
                        }
                    }
                });

提供一个python 版本的numpy ctc的代码方便理解

import numpy as np


ninf = -np.float('inf')

def _logsumexp(a, b):
    '''
    np.log(np.exp(a) + np.exp(b))
    '''

    if a < b:
        a, b = b, a

    if b == ninf:
        return a
    else:
        return a + np.log(1 + np.exp(b - a)) 

def logsumexp(*args):
    '''
    from scipy.special import logsumexp
    logsumexp(args)
    '''
    res = args[0]
    for e in args[1:]:
        res = _logsumexp(res, e)
    return res
class CTC:
    def __init__(self):
        pass

    def forward(self):
        pass

    def alpha(self, log_y, labels):
        ##alpha 为前向概率
        T, V = log_y.shape
        L = len(labels)
        log_alpha = np.ones([T, L]) * ninf

        # init
        ## 初始化动态规划
        log_alpha[0, 0] = log_y[0, labels[0]]
        log_alpha[0, 1] = log_y[0, labels[1]]

        ##计算每一步,每个GT的前向概率
        for t in range(1, T):
            for i in range(L):
                s = labels[i]

                a = log_alpha[t - 1, i]
                if i - 1 >= 0:
                    a = logsumexp(a, log_alpha[t - 1, i - 1])
                ##如果当前不是空白符,得加前两步的状态
                if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
                    a = logsumexp(a, log_alpha[t - 1, i - 2])

                log_alpha[t, i] = a + log_y[t, s]

        return log_alpha


    def beta(self, log_y, labels):
        ##计算后向概率
        T, V = log_y.shape
        L = len(labels)
        log_beta = np.ones([T, L]) * ninf

        # init
        log_beta[-1, -1] = log_y[-1, labels[-1]]
        log_beta[-1, -2] = log_y[-1, labels[-2]]

        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]

                a = log_beta[t + 1, i]
                if i + 1 < L:
                    a = logsumexp(a, log_beta[t + 1, i + 1])
                if i + 2 < L and s != 0 and s != labels[i + 2]:
                    a = logsumexp(a, log_beta[t + 1, i + 2])

                log_beta[t, i] = a + log_y[t, s]

        return log_beta

    def backward(selflog_y, labels):
        T, V = log_y.shape
        L = len(labels)

        log_alpha = self.alpha(log_y, labels)
        log_beta = self.beta(log_y, labels)
        log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])
        ##任意时刻的
        log_grad = np.ones([T, V]) * ninf
        for t in range(T):
            for s in range(V):
                lab = [i for i, c in enumerate(labels) if c == s]
                for i in lab:
                    log_grad[t, s] = logsumexp(log_grad[t, s],
                                            log_alpha[t, i] + log_beta[t, i])
                log_grad[t, s] -= 2 * log_y[t, s]

        log_grad -= log_p
        return log_grad

    def predict(self):
        pass

    def ctc_prefix(self):
        pass

    def ctc_beamsearch(self):
        pass

    def alpha_vanilla(self, y, labels):
        T, V = y.shape  # T,time step, V: probs
        L = len(labels) # label length
        alpha = np.zeros([T, L])

        # init
        alpha[0, 0] = y[0, labels[0]]
        alpha[0, 1] = y[0, labels[1]]

        for t in range(1, T):
            for i in range(L):
                s = labels[i]

                a = alpha[t - 1, i]
                if i - 1 >= 0:
                    a += alpha[t - 1, i - 1]
                if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
                    a += alpha[t - 1, i - 2]

                alpha[t, i] = a * y[t, s]

        return alpha

    def beta_vanilla(self, y, labels):
        ##原始版计算前向概率,没在对数域中计算
        T, V = y.shape
        L = len(labels)
        beta = np.zeros([T, L])

        # init
        beta[-1, -1] = y[-1, labels[-1]]
        beta[-1, -2] = y[-1, labels[-2]]

        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]

                a = beta[t + 1, i]
                if i + 1 < L:
                    a += beta[t + 1, i + 1]
                if i + 2 < L and s != 0 and s != labels[i + 2]:
                    a += beta[t + 1, i + 2]

                beta[t, i] = a * y[t, s]

        return beta

    def gradient(self, y, labels):
        T, V = y.shape
        L = len(labels)

        alpha = self.alpha_vanilla(y, labels)
        beta = self.beta(y, labels)
        p = alpha[-1, -1] + alpha[-1, -2]

        grad = np.zeros([T, V])
        for t in range(T):
            for s in range(V):
                lab = [i for i, c in enumerate(labels) if c == s]
                for i in lab:
                    grad[t, s] += alpha[t, i] * beta[t, i]
                grad[t, s] /= y[t, s] ** 2

        grad /= p
        return grad



 def check_grad(y, labels, w=-1, v=-1, toleration=1e-3):
    grad_1 = gradient(y, labels)[w, v]

    delta = 1e-10
    original = y[w, v]

    y[w, v] = original + delta
    alpha = forward(y, labels)
    log_p1 = np.log(alpha[-1, -1] + alpha[-1, -2])

    y[w, v] = original - delta
    alpha = forward(y, labels)
    log_p2 = np.log(alpha[-1, -1] + alpha[-1, -2])

    y[w, v] = original

    grad_2 = (log_p1 - log_p2) / (2 * delta)
    if np.abs(grad_1 - grad_2) > toleration:
        print('[%d, %d]:%.2e' % (w, v, np.abs(grad_1 - grad_2)))


def remove_blank(labels, blank=0):
    new_labels = []

    # combine duplicate
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l

    # remove blank
    new_labels = [l for l in new_labels if l != blank]

    return new_labels

def insert_blank(labels, blank=0):
    new_labels = [blank]
    for l in labels:
        new_labels += [l, blank]
    return new_labels

def greedy_decode(y, blank=0):
    raw_rs = np.argmax(y, axis=1)
    rs = remove_blank(raw_rs, blank)
    return raw_rs, rs

def beam_decode(y, beam_size=10):
    T, V = y.shape
    log_y = np.log(y)

    beam = [([], 0)]
    for t in range(T):  # for every timestep
        new_beam = []
        for prefix, score in beam:
            for i in range(V):  # for every state
                new_prefix = prefix + [i]
                new_score = score + log_y[t, i]

                new_beam.append((new_prefix, new_score))

        # top beam_size
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    return beam

def prefix_beam_decode(y, beam_size=10, blank=0):
    T, V = y.shape
    log_y = np.log(y)

    beam = [(tuple(), (0, ninf))]  # blank, non-blank
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda : (ninf, ninf))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_y[t, i]

                if i == blank:  # propose a blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None

                    # exntend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # keep current prefix
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # top beam_size
        beam = sorted(new_beam.items(), key=lambda x : logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值