【代码实践】focal loss损失函数及其变形原理详细讲解和图像分割实践(含源码)


【代码实践】focal loss及其变形原理详细讲解和图像分割实践(含源码)

Focal Loss 是一种用于解决类别不平衡问题的损失函数,在目标检测和语义分割等任务中得到了广泛应用。传统的交叉熵损失函数(Cross-entropy Loss)在处理类别不平衡问题时,容易受到多数类别的影响,导致模型对少数类别的分类效果较差。Focal Loss 通过引入一个可调参数 γ \gamma γ ,对少数类别的错误分类进行更加强烈的惩罚,从而提高模型对少数类别的分类能力。

1.Focal Loss定义

Focal Loss公式:
L f o c a l − l o s s = ( 1 − p t ) γ ⋅ log ⁡ ( p t ) L_{focal-loss} = (1 - p_t)^\gamma \cdot \log(p_t) Lfocalloss=(1pt)γlog(pt)
其中:
{ p t = p , y = 1 p t = 1 − p , o t h e r w i s e \left\{ \begin{aligned} \quad p_t &= p &,& y = 1\\ \quad p_t &= 1 - p&,&otherwise \end{aligned} \right. {ptpt=p=1p,,y=1otherwise

其中, p t p^t pt 表示模型预测为正类别的概率, γ \gamma γ为平衡系数,用于调整正类别和负类别的平
衡。 γ = 0 \gamma=0 γ=0时,该focal loss损失函数就退化为普通的交叉熵损失函数,如下面公式所示:
L c e = L ( y , p ) = − y log ⁡ ( p ) − ( 1 − y ) log ⁡ ( 1 − p ) L_{ce} = L(y, p) = -y \log(p) - (1 - y) \log(1 - p) Lce=L(y,p)=ylog(p)(1y)log(1p)

2.带权重的交叉熵损失函数

公式如下:
L w c e = 1 N ( ∑ y i = 1 m ( − α log ⁡ p ) + ∑ y i = 0 n − ( 1 − α ) l o g ( 1 − p ) ) L_{wce} = \frac{1}{N} \left( \sum_{y_i=1}^{m} (-\alpha \log p) + \sum_{y_i=0}^{n} -(1-\alpha)log(1-p) \right) Lwce=N1(yi=1m(αlogp)+yi=0n(1α)log(1p))
其中:
α 1 − α = n m \frac{\alpha}{1-\alpha} = \frac{n}{m} 1αα=mn
权重的大小根据正负样本的分布进行设置。

3.带权重的Focal Loss

3.1公式定义以及函数图像

受到带权重的交叉熵损失函数(章节2)的启发,则产生了带权重的 focal loss,将参数 α \alpha α引入focal loss 中,起到了对正负样本更强的平衡作用,函数定义如下公式:
L w f l = { − ( 1 − α ) p t γ log ⁡ ( 1 − p t ) 当  y = 0 − α ( 1 − p t ) γ log ⁡ ( p t ) 当  y = 1 L_{wfl} = \begin{cases} -(1 - \alpha) p_t^\gamma \log(1 - p_t) & \text{当 } y = 0 \\ -\alpha (1 - p_t)^\gamma \log(p_t) & \text{当 } y = 1 \end{cases} Lwfl={(1α)ptγlog(1pt)α(1pt)γlog(pt) y=0 y=1
在本文中采用 α = 0.25 , γ = 2 \alpha=0.25, \gamma=2 α=0.25,γ=2 作为参数值进行该损失函数设计,函数图像如下:
在这里插入图片描述

3.2原理解释(为什么能平衡正负样本)

观察图像 y = 0 y=0 y=0,这个曲线表示表示预测错误的情况,由函数图像可见,在预测错误的情况下,随着预测概率的提高,其对应的函数值越大,损失值也相应地增大,因此在反向传播时,训练错误的情况在全部loss中占据更大的比例,能更多地进行反向传播,使得模型训练更专注在负样本上。 y = 1 y=1 y=1的情况则与之相反。

4.代码编写

4.1 二分类focal loss

#适用于二分类的focal loss
class BinaryFocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2): # 定义alpha和gamma变量
    super(BinaryFocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma

# 前向传播
def forward(self, preds, labels):
    eps = 1e-7  # 防止数值超出定义域
    # 开始计算
    loss_y1 = -1 * self.alpha * \
        torch.pow((1 - preds), self.gamma) * \
        torch.log(preds + eps) * labels
    loss_y0 = -1 * (1 - self.alpha) * torch.pow(preds,
                                                self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
    loss = loss_y0 + loss_y1
    return torch.mean(loss)

4.2 多分类focal loss

利用二分类的focal loss即可顺利写出多分类的focal loss如下:

# 多分类focal loss
class MultiFocalLoss(nn.Module):
def __init__(self):
    super(MultiFocalLoss, self).__init__()
    
# 前向传播,注意我们在计算损失函数时,比如在图像分割任务中,我们需要
# 使用one-hot编码将多分类任务转为多个二分类任务进行计算。
def forward(self, preds, labels):
    total_loss = 0
    # 使用了二分类的focal loss
    binary_focal_loss = BinaryFocalLoss()
    logits = F.softmax(preds, dim=1)
    # 这里shape时[B,C,W,H],通道一就是class num
    nums = labels.shape[1]
    for i in range(nums):
        loss = binary_focal_loss(logits[:, i], labels[:, i])
        total_loss += loss
    return total_loss / nums

4.3 pytorch具体使用示例

def train(args):
    model = TransRes1Unet(1, 10).to(args.device) # 初始化自己模型
    batch_size = args.batch_size # 初始化batch-size
    criterion = my_loss.MultiFocalLoss()  # 初始化这里定义的focal loss
    optimizer = optim.Adam(model.parameters(), lr=0.001)   # 初始化优化器
    # 初始化自己的数据集
    ms_dataset = MSDataset(
        args.train_data_folder, transform=x_transforms, target_transform=y_transforms)
    # 构建dataloader
    dataloaders = DataLoader(
        ms_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    # 开始训练
    train_model(args, model, criterion, optimizer, dataloaders)

可以看到,在训练代码一开始初始化时使用了focal loss作为损失函数参与模型训练。

for x, y in dataload:
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            step += 1
            inputs = x.to(args.device)
            labels = y.to(args.device)
            # forward
            outputs = model(inputs)
            # 这里便使用了初始化好的loss函数来计算loss
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()   # 反向传播
        scaler.step(optimizer=optimizer)
        scaler.update()
        # zero the parameter gradients
        optimizer.zero_grad()

5. 应用focal loss的分割效果

在这里插入图片描述
第一行为手工标注的真值,第二行为使用focal loss的unet变形预测结果,第三行为原始图像

可以看出来,使用纯focal loss来进行图像分割任务比起其他混合型loss来说效果依然不错,在大目标的分割上效果较好小目标也依然能分割出来

美中不足的时可能过于关注细节,导致最大的目标分割出现了残缺,这个需要引入其他loss函数进行调节,不在本篇文章讨论范围。



关注博主,收获更多干货


需要全部源代码或者更多支持的请关注点赞收藏博主并在评论区评论噢!


往期精彩干货
基于mmdetection3d的单目3D目标检测模型,效果远超CenterNet3D
SSH?Termius?一篇文章教你使用远程服务器训练
Jetson nano开机自启动python程序
【代码实践】dice loss及其变形原理详细讲解和图像分割实践(含源码)

  • 17
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
Focal Loss是一种用于解决类别不平衡问题的损失函数,特别适用于目标检测和图像分割任务。它由Lin等人在2017年提出。 传统的交叉熵损失函数在处理类别不平衡问题时存在一些问题,例如在一个大多数为负样本的数据集中,模型可能倾向于预测为负样本,导致正样本的预测效果较差。Focal Loss通过引入一个可调参数来解决这个问题。 Focal Loss的核心思想是减少易分类样本的权重,以便模型更加关注困难样本。它通过引入一个平衡因子(1-π)^γ,其中π表示预测概率,γ是一个可调参数。 具体来说,Focal Loss的计算公式如下: FL = -α(1-π)^γ * log(π) 其中,α是一个平衡因子,用于调整正负样本的权重。当α接近0时,负样本的权重增大;当α接近1时,正负样本的权重相等。 通过引入(1-π)^γ,Focal Loss可以减少易分类样本的权重,让模型更加关注困难样本。当样本的预测概率π接近1时,(1-π)^γ趋近于0,损失函数的权重减少;当样本的预测概率π接近0时,(1-π)^γ趋近于1,损失函数的权重增加。 通过调整参数γ,可以控制Focal Loss对于易分类样本的关注程度。较大的γ会增加对易分类样本的关注,而较小的γ会减少对易分类样本的关注。 总结来说,Focal Loss通过减少易分类样本的权重,使得模型更加关注困难样本,从而提高模型在类别不平衡问题上的性能。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WanHeng WyattVan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值