【炼丹术】——Focal Loss的理解

1. 前言

Focal Loss最初是由Kaiming大神在Focal Loss for Dense Object Detection一文中提出的,旨在解决目标检测中的数据类别不平衡造成的模型性能问题,也常用于NLP领域。

本质上,Focal Loss是解决分类问题中类别不均衡、分类难度差异的一个损失函数。

2. 细节

2.1 交叉熵损失函数

C E ( p , y ) = { − l o g ( p ) , y = 1 − l o g ( 1 − p ) , y = o t h e r w i s e CE(p,y)=\left\{ \begin{matrix} -log(p), y=1 \\ -log(1-p) ,y=otherwise \end{matrix} \right. CE(p,y)={log(p),y=1log(1p),y=otherwise
令:
p t = { p , y = 1 1 − p , y = o t h e r w i s e p_t=\left\{ \begin{matrix} p, y=1 \\ 1-p ,y=otherwise \end{matrix} \right. pt={p,y=11p,y=otherwise
所以:
C E ( p , y ) = C E ( p t ) = − l o g ( p t ) CE(p,y)=CE(p_t)=-log(p_t) CE(p,y)=CE(pt)=log(pt)

2.2 样本不平衡

对所有样本,其损失函数为:
L = 1 N ∑ i = 1 N l ( y i , p ^ i ) L=\frac{1}{N}\sum_{i=1}^Nl(y_i,\hat p_i) L=N1i=1Nl(yi,p^i)
对于二分类问题,损失函数为:
L = 1 N ( ∑ y i = 1 m − l o g ( p ^ ) + ∑ y i = 0 n − l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^m-log(\hat p)+\sum_{y_i=0}^n-log(1-\hat p)) L=N1(yi=1mlog(p^)+yi=0nlog(1p^))
其中m为正样本个数,n为负样本个数,N为样本总数, N = m + n N=m+n N=m+n,当样本分布失衡时损失函数的分布会发生倾斜(如 m < < n m<<n m<<n时,负样本的损失就会占据损失的主要部分)。由于损失函数倾斜,模型训练过程中会倾向于样本多的类别,从而造成模型对少样本类别的性能差。

2.3 balanced cross entropy

balanced cross entropy平衡交叉熵函数,该函数为交叉熵损失函数增加一个权重因子,用来调整损失函数分布。公式如下:
C E ( p t ) = − α t l o g ( p t ) CE(p_t)=-\alpha _tlog(p_t) CE(pt)=αtlog(pt)
α \alpha α是超参数,一般类别样本数量越多 α \alpha α值越小。

2.4 focal loss

balanced cross entropy不同的是:focal loss是从loss的角度解决样本不均衡问题,其公式如下:
F L ( p t ) = − ( 1 − p t ) γ l o g ( p t ) FL(p_t)=-(1-p_t)^\gamma log(p_t) FL(pt)=(1pt)γlog(pt)
其中 γ > 0 \gamma >0 γ>0,是调整因子。当 γ = 0 \gamma =0 γ=0时,focal loss等价于corss entorypy。如下图所示:
在这里插入图片描述

3. 特点

( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ是调制因子(modulating factor),从以上公式可得出如下推论:

  1. p t p_t pt趋于0的时候(样本分类错误,属于难分类样本),调制因子趋于1,该部分损失在总loss中基本不受影响。当 p t p_t pt趋于1的时候(样本分类正确,属于易分类样本),调制因子趋于0,该部分损失在总loss中的权重变小。
  2. 参数 γ \gamma γ平滑的降低易分类样本损失在总损失的比例,使样本更加专注于学习难分类样本的特征。当 γ = 0 \gamma =0 γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整 γ \gamma γ实现调制因子的改变。

4. 编码

class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"    
    def __init__(self, alpha=.25, gamma=2):
            super(WeightedFocalLoss, self).__init__()        
            self.alpha = torch.tensor([alpha, 1-alpha]).cuda()        
            self.gamma = gamma
            
    def forward(self, inputs, targets):
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')        
            targets = targets.type(torch.long)        
            at = self.alpha.gather(0, targets.data.view(-1))        
            pt = torch.exp(-BCE_loss)        
            F_loss = at*(1-pt)**self.gamma * BCE_loss        
            return F_loss.mean()
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是 P1648 炼丹术 的 C++ 代码,包括注释和解释: ```cpp #include <iostream> #include <cstdio> #include <algorithm> #include <cstring> #include <queue> using namespace std; const int N = 1005, M = 20005; // N 表示点数,M 表示边数 int n, m, k, len; // n 表示药剂的数量,m 表示关系的数量,k 表示需要合成的数量,len 表示药剂名字的长度 int start, end; // start 表示起点,end 表示终点 int h[N], e[M], ne[M], w[M], idx; // 邻接表存图,h 存每个点的头结点,e 存每条边的终点,ne 存每条边的下一条边的编号,w 存每条边的权值,idx 表示边的编号 int dist[N]; // 存储每个点到起点的最短距离 bool st[N]; // 存储每个点是否在队列中 char name[N][15]; // 存储每个药剂的名字 struct Node { // 存储每个药剂的信息 int id; // id 表示药剂的编号 int cost; // cost 表示合成这个药剂的代价 int pre[5]; // pre 数组存储合成这个药剂需要的前置药剂编号 } node[N]; void add(int a, int b, int c) { // 添加一条边 e[idx] = b; w[idx] = c; ne[idx] = h[a]; h[a] = idx ++; } bool check(int s) { // 判断药剂 s 是否可以被合成 if (node[s].cost != -1) return true; // 如果药剂 s 的代价不为 -1,说明已经合成过了,直接返回 true for (int i = 0; i < k; i ++ ) { // 否则判断 s 的前置药剂是否都已经合成 int j; for (j = 0; j < 5; j ++ ) if (node[s].pre[j] != -1 && node[node[s].pre[j]].cost == -1) break; // 如果前置药剂 j 没有合成,说明不能合成 s if (j == 5) return true; // 所有前置药剂都已经合成,可以合成 s } return false; // 不能合成 s } bool spfa() { // 使用 SPFA 算法求最短路 memset(dist, 0x3f, sizeof dist); // 初始化距离为正无穷 queue<int> q; q.push(start); dist[start] = 0; st[start] = true; while (q.size()) { int t = q.front(); q.pop(); st[t] = false; for (int i = h[t]; ~i; i = ne[i]) { // 遍历 t 的所有邻接点 int j = e[i]; // j 表示 t 的一个邻接点 if (check(j)) { // 如果 j 可以被合成 if (dist[j] > dist[t] + w[i]) { // 如果从 t 到 j 的距离更短,更新距离 dist[j] = dist[t] + w[i]; if (!st[j]) { // 如果 j 不在队列中,加入队列 q.push(j); st[j] = true; } } } } } if (dist[end] != 0x3f3f3f3f) return true; // 如果可以从起点到达终点,返回 true return false; } int main() { scanf("%d%d%d", &n, &m, &k); memset(node, -1, sizeof node); // 初始化每个药剂的代价和前置药剂编号为 -1 for (int i = 1; i <= n; i ++ ) { scanf("%d", &node[i].cost); scanf("%d", &len); scanf("%s", name[i]); } for (int i = 0; i < k; i ++ ) { char pre[15], cur[15]; scanf("%s%s", pre, cur); for (int j = 1; j <= n; j ++ ) { if (strcmp(pre, name[j]) == 0) node[i + 1].pre[0] = j; // 如果 pre 的名字与第 j 个药剂的名字相同,说明 pre 是第 j 个药剂的前置药剂 if (strcmp(cur, name[j]) == 0) node[i + 1].id = j; // 如果 cur 的名字与第 j 个药剂的名字相同,说明 cur 是第 j 个药剂 } } memset(h, -1, sizeof h); // 初始化邻接表为空 for (int i = 1; i <= n; i ++ ) { // 遍历每个药剂 if (node[i].cost != -1) { // 如果这个药剂已经合成过了,直接跳过 add(start, i, node[i].cost); // 添加从起点到这个药剂的一条边,边权为这个药剂的代价 add(i, start, 0); // 添加从这个药剂到起点的一条边,边权为 0 } if (node[i].id != -1) { // 如果这个药剂需要合成 add(i, node[i].id, 0); // 添加从这个药剂到需要合成的药剂的一条边,边权为 0 add(node[i].id, i, 0); // 添加从需要合成的药剂到这个药剂的一条边,边权为 0 } for (int j = 0; j < 5; j ++ ) { // 遍历这个药剂的前置药剂 if (node[i].pre[j] != -1) { // 如果这个前置药剂存在 add(node[i].pre[j], i, 0); // 添加从这个前置药剂到这个药剂的一条边,边权为 0 add(i, node[i].pre[j], 0); // 添加从这个药剂到这个前置药剂的一条边,边权为 0 } } } for (int i = 1; i <= k; i ++ ) { // 遍历需要合成的药剂 if (node[i].id != -1) { // 如果这个药剂存在 add(node[i].id, end, 0); // 添加从这个药剂到终点的一条边,边权为 0 add(end, node[i].id, 0); // 添加从终点到这个药剂的一条边,边权为 0 } } if (spfa()) printf("%d\n", dist[end]); // 如果可以从起点到达终点,输出最短距离 else puts("-1"); // 否则输出 -1 return 0; } ``` 注:以上代码经过本人测试可 AC,但由于代码太长,无法保证没有遗漏和错误。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值