【代码实践】focal loss及其变形原理详细讲解和图像分割实践(含源码)
Focal Loss 是一种用于解决类别不平衡问题的损失函数,在目标检测和语义分割等任务中得到了广泛应用。传统的交叉熵损失函数(Cross-entropy Loss)在处理类别不平衡问题时,容易受到多数类别的影响,导致模型对少数类别的分类效果较差。Focal Loss 通过引入一个可调参数 γ \gamma γ ,对少数类别的错误分类进行更加强烈的惩罚,从而提高模型对少数类别的分类能力。
文章目录
1.Focal Loss定义
Focal Loss公式:
L
f
o
c
a
l
−
l
o
s
s
=
(
1
−
p
t
)
γ
⋅
log
(
p
t
)
L_{focal-loss} = (1 - p_t)^\gamma \cdot \log(p_t)
Lfocal−loss=(1−pt)γ⋅log(pt)
其中:
{
p
t
=
p
,
y
=
1
p
t
=
1
−
p
,
o
t
h
e
r
w
i
s
e
\left\{ \begin{aligned} \quad p_t &= p &,& y = 1\\ \quad p_t &= 1 - p&,&otherwise \end{aligned} \right.
{ptpt=p=1−p,,y=1otherwise
其中,
p
t
p^t
pt 表示模型预测为正类别的概率,
γ
\gamma
γ为平衡系数,用于调整正类别和负类别的平
衡。当
γ
=
0
\gamma=0
γ=0时,该focal loss损失函数就退化为普通的交叉熵损失函数,如下面公式所示:
L
c
e
=
L
(
y
,
p
)
=
−
y
log
(
p
)
−
(
1
−
y
)
log
(
1
−
p
)
L_{ce} = L(y, p) = -y \log(p) - (1 - y) \log(1 - p)
Lce=L(y,p)=−ylog(p)−(1−y)log(1−p)
2.带权重的交叉熵损失函数
公式如下:
L
w
c
e
=
1
N
(
∑
y
i
=
1
m
(
−
α
log
p
)
+
∑
y
i
=
0
n
−
(
1
−
α
)
l
o
g
(
1
−
p
)
)
L_{wce} = \frac{1}{N} \left( \sum_{y_i=1}^{m} (-\alpha \log p) + \sum_{y_i=0}^{n} -(1-\alpha)log(1-p) \right)
Lwce=N1(yi=1∑m(−αlogp)+yi=0∑n−(1−α)log(1−p))
其中:
α
1
−
α
=
n
m
\frac{\alpha}{1-\alpha} = \frac{n}{m}
1−αα=mn
即权重的大小根据正负样本的分布进行设置。
3.带权重的Focal Loss
3.1公式定义以及函数图像
受到带权重的交叉熵损失函数(章节2)的启发,则产生了带权重的 focal loss,将参数
α
\alpha
α引入focal loss 中,起到了对正负样本更强的平衡作用,函数定义如下公式:
L
w
f
l
=
{
−
(
1
−
α
)
p
t
γ
log
(
1
−
p
t
)
当
y
=
0
−
α
(
1
−
p
t
)
γ
log
(
p
t
)
当
y
=
1
L_{wfl} = \begin{cases} -(1 - \alpha) p_t^\gamma \log(1 - p_t) & \text{当 } y = 0 \\ -\alpha (1 - p_t)^\gamma \log(p_t) & \text{当 } y = 1 \end{cases}
Lwfl={−(1−α)ptγlog(1−pt)−α(1−pt)γlog(pt)当 y=0当 y=1
在本文中采用
α
=
0.25
,
γ
=
2
\alpha=0.25, \gamma=2
α=0.25,γ=2 作为参数值进行该损失函数设计,函数图像如下:
3.2原理解释(为什么能平衡正负样本)
观察图像 y = 0 y=0 y=0,这个曲线表示表示预测错误的情况,由函数图像可见,在预测错误的情况下,随着预测概率的提高,其对应的函数值越大,损失值也相应地增大,因此在反向传播时,训练错误的情况在全部loss中占据更大的比例,能更多地进行反向传播,使得模型训练更专注在负样本上。 y = 1 y=1 y=1的情况则与之相反。
4.代码编写
4.1 二分类focal loss
#适用于二分类的focal loss
class BinaryFocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2): # 定义alpha和gamma变量
super(BinaryFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
# 前向传播
def forward(self, preds, labels):
eps = 1e-7 # 防止数值超出定义域
# 开始计算
loss_y1 = -1 * self.alpha * \
torch.pow((1 - preds), self.gamma) * \
torch.log(preds + eps) * labels
loss_y0 = -1 * (1 - self.alpha) * torch.pow(preds,
self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
loss = loss_y0 + loss_y1
return torch.mean(loss)
4.2 多分类focal loss
利用二分类的focal loss即可顺利写出多分类的focal loss如下:
# 多分类focal loss
class MultiFocalLoss(nn.Module):
def __init__(self):
super(MultiFocalLoss, self).__init__()
# 前向传播,注意我们在计算损失函数时,比如在图像分割任务中,我们需要
# 使用one-hot编码将多分类任务转为多个二分类任务进行计算。
def forward(self, preds, labels):
total_loss = 0
# 使用了二分类的focal loss
binary_focal_loss = BinaryFocalLoss()
logits = F.softmax(preds, dim=1)
# 这里shape时[B,C,W,H],通道一就是class num
nums = labels.shape[1]
for i in range(nums):
loss = binary_focal_loss(logits[:, i], labels[:, i])
total_loss += loss
return total_loss / nums
4.3 pytorch具体使用示例
def train(args):
model = TransRes1Unet(1, 10).to(args.device) # 初始化自己模型
batch_size = args.batch_size # 初始化batch-size
criterion = my_loss.MultiFocalLoss() # 初始化这里定义的focal loss
optimizer = optim.Adam(model.parameters(), lr=0.001) # 初始化优化器
# 初始化自己的数据集
ms_dataset = MSDataset(
args.train_data_folder, transform=x_transforms, target_transform=y_transforms)
# 构建dataloader
dataloaders = DataLoader(
ms_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# 开始训练
train_model(args, model, criterion, optimizer, dataloaders)
可以看到,在训练代码一开始初始化时使用了focal loss作为损失函数参与模型训练。
for x, y in dataload:
with torch.autocast(device_type='cuda', dtype=torch.float32):
step += 1
inputs = x.to(args.device)
labels = y.to(args.device)
# forward
outputs = model(inputs)
# 这里便使用了初始化好的loss函数来计算loss
loss = criterion(outputs, labels)
scaler.scale(loss).backward() # 反向传播
scaler.step(optimizer=optimizer)
scaler.update()
# zero the parameter gradients
optimizer.zero_grad()
5. 应用focal loss的分割效果
第一行为手工标注的真值,第二行为使用focal loss的unet变形预测结果,第三行为原始图像。
可以看出来,使用纯focal loss来进行图像分割任务比起其他混合型loss来说效果依然不错,在大目标的分割上效果较好,小目标也依然能分割出来。
美中不足的时可能过于关注细节,导致最大的目标分割出现了残缺,这个需要引入其他loss函数进行调节,不在本篇文章讨论范围。
关注博主,收获更多干货
需要全部源代码或者更多支持的请关注点赞收藏博主并在评论区评论噢!
往期精彩干货
基于mmdetection3d的单目3D目标检测模型,效果远超CenterNet3D
SSH?Termius?一篇文章教你使用远程服务器训练
Jetson nano开机自启动python程序
【代码实践】dice loss及其变形原理详细讲解和图像分割实践(含源码)