一、聚焦损失函数和平衡因子基础知识
- 从分类错误代价和样本困难程序两个方面思考,有时间再写。
二、基于平衡因子的聚焦损失函数的python(torch)类实现
- 该代码中增加了对nan的调试代码,可以用于快速寻找什么地方导致了nan
- 代码中增加了变异系数的计算,可以方便查看在算法迭代过程中各类样本loss的分布情况,返回差异系数越大,表示类别之间的loss差异越大,越不均衡。最好的情况为:算法迭代前后,差异系数值变小。
class FocalLoss(nn.Module):
def __init__(self, gamma=4.5, alpha=0.05):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
if gamma==0:
self.x = 0
else:
self.x = 1
def forward