focal loss在ocr seq2seq中文文字识别模型中应用思考

一、简介

       focal loss是有微软resnet的研究者何恺明提出的,论文地址(https://arxiv.org/pdf/1708.02002.pdf),其主要是为了解决目标检测中目标与背景样本不均衡问题而提出的。ocr中文文字识别同样有样本不均衡问题,一些生僻字经常会识别错误。而seq2seq是一种基于时序的模型,在ocr中文识别中也会应用到。

      本文主要介绍用focal loss进行基于seq2seq的中文文字识别的一些失败应用及原因分析。

二、基于focal loss的ocr seq2seq模型训练

       focal loss的损失函数公式如下:

     ocr seq2seq采用的模型代码链接如下:https://github.com/clovaai/deep-text-recognition-benchmark

     经过对不同 和调参训练发现focal loss的模型该开始拟合较快,但随着训练进程,过拟合越来越严重,验证集精度没有ce loss提升更快,最后放弃训练。

三、过拟合原因分析

       对于focal loss的介绍,不得不从其对比的ce loss(cross entropy loss)说起。下面以逻辑回归为例,说明ce loss是普通情况下的最优损失函数:

      当求A对于Bi类的概率相当于求p(Bi|A),根据贝叶斯公式可求:

      其中A为x与wi的协方差。P(ABi)就是假定无噪声的x与wi线性相关,且有噪声的x与wi的协方差符合正态概率分布。我们要做的是最大化P(ABi)的概率,就是最大化x与x是Bi类同时发生的概率。由于假定x的噪声符合正态分布,所以ce loss是假定x与w线性相关下最普遍的损失函数。

       focal loss的函数图像如下

    

     当判断的概率越低,对该正确类型的梯度越大,对后面参数的调整越大。是一种对ce loss的改进。

     但基于seq2seq的模型,在训练及预测过程中会考虑P("人"|“我是中国”)*P("人")概率,而“人”的梯度过大,也会增大"人"|"我是中国"时序网络的梯度,使模型更容易记住训练集中语言模式,而不是通过读文字记住图像中的文字。

     

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值