文章目录
1、 环境
1.1 paddlepaddle版本
paddlepaddle-gpu 2.2.2
1.2 PaddleClas版本
* release/2.4
2、 CS_CE Loss
CS_CE Loss在CEloss上加入了对于类别数量的权重系数,增加模型对于样本数量少类别的回归能力
计算方式:
L
o
s
s
(
z
,
c
)
=
−
(
N
m
i
n
N
c
)
γ
∗
C
r
o
s
s
E
n
t
r
o
p
y
(
z
,
c
)
Loss(z, c) = - (\frac{N_{min}}{N_c})^\gamma * CrossEntropy(z, c)
Loss(z,c)=−(NcNmin)γ∗CrossEntropy(z,c)
其中:
- γ \gamma γ为控制权重的超参数
- N m i n N_{min} Nmin为存在样本最少的类别的样本数量
- N c N_c Nc为c类别样本数量
3、paddle代码
import paddle.nn as nn
import paddle
import numpy as np
import paddle.nn.functional as F
class CostSensitiveCE(nn.Layer):
r"""
Equation: Loss(z, c) = - (\frac{N_min}{N_c})^\gamma * CrossEntropy(z, c),
where gamma is a hyper-parameter to control the weights,
N_min is the number of images in the smallest class,
and N_c is the number of images in the class c.
The representative re-weighting methods, which assigns class-dependent weights to the loss function
Args:
gamma (float or double): to control the loss weights: (N_min/N_i)^gamma
"""
def __init__(self, num_class_list, gamma):
super(CostSensitiveCE, self).__init__()
self.num_class_list = num_class_list
self.csce_weight = paddle.to_tensor(np.array([(min(self.num_class_list) / N)**gamma for N in self.num_class_list], dtype=np.float32))
def forward(self, x, label):
if isinstance(x, dict):
x = x["logits"]
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
cs_ce_loss = F.cross_entropy(x, label=label, soft_label=soft_label, weight=self.csce_weight)
cs_ce_loss = cs_ce_loss.mean()
return {"CS_CELoss":cs_ce_loss}
4、并入PaddleClas
4.1 loss代码
在PaddeClas/ppcls/loss
目录下添加文件cs_celoss.py
文件,并在PaddeClas/ppcls/loss/__init__.py
中添加
from .cs_celoss import CostSensitiveCE
4.2 config 文件
Loss:
Train:
- CostSensitiveCE:
gamma: 1
num_class_list: [4400, 1520, 560, 1680] # 每个类别样本数量
weight: 1 # 一个Classfier默认为1
Eval:
- CostSensitiveCE:
gamma: 1
num_class_list: [4400, 1520, 560, 1680] # 每个类别样本数量
weight: 1
5、参考文献
- 论文:Bag of Tricks for Long-Tailed Visual Recognition with Deep Convolutional Neural Networks
- 官方代码:https://github.com/zhangyongshun/BagofTricks-LT
本帖写于:2022年8月6号, 未经本人允许,禁止转载。