一、简介
focal loss是有微软resnet的研究者何恺明提出的,论文地址(https://arxiv.org/pdf/1708.02002.pdf),其主要是为了解决目标检测中目标与背景样本不均衡问题而提出的。ocr中文文字识别同样有样本不均衡问题,一些生僻字经常会识别错误。而seq2seq是一种基于时序的模型,在ocr中文识别中也会应用到。
本文主要介绍用focal loss进行基于seq2seq的中文文字识别的一些失败应用及原因分析。
二、基于focal loss的ocr seq2seq模型训练
focal loss的损失函数公式如下:
ocr seq2seq采用的模型代码链接如下:https://github.com/clovaai/deep-text-recognition-benchmark
经过对不同 和调参训练发现focal loss的模型该开始拟合较快,但随着训练进程,过拟合越来越严重,验证集精度没有ce loss提升更快,最后放弃训练。
三、过拟合原因分析
对于focal loss的介绍,不得不从其对比的ce loss(cross entropy loss)说起。下面以逻辑回归为例,说明ce loss是普通情况下的最优损失函数:
当求A对于Bi类的概率相当于求p(Bi|A),根据贝叶斯公式可求:
其中A为x与wi的协方差。P(ABi)就是假定无噪声的x与wi线性相关,且有噪声的x与wi的协方差符合正态概率分布。我们要做的是最大化P(ABi)的概率,就是最大化x与x是Bi类同时发生的概率。由于假定x的噪声符合正态分布,所以ce loss是假定x与w线性相关下最普遍的损失函数。
focal loss的函数图像如下
当判断的概率越低,对该正确类型的梯度越大,对后面参数的调整越大。是一种对ce loss的改进。
但基于seq2seq的模型,在训练及预测过程中会考虑P("人"|“我是中国”)*P("人")概率,而“人”的梯度过大,也会增大"人"|"我是中国"时序网络的梯度,使模型更容易记住训练集中语言模式,而不是通过读文字记住图像中的文字。