【DL】FocalLoss的PyTorch实现

【DL】FocalLoss的PyTorch实现

此篇不介绍FocalLoss的原理,仅展示PyTorch实现FocalLoss的两种方式。个人认为相关原理已在文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》中讲得很清晰,故此篇不再介绍。

方式一

同时计算一个batch中所有样本关于FocalLoss的损失值(来自文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》,个人补充了一些注释):

import torch
from torch import nn
import random
class FocalLoss(nn.Module):
    """
    参考 https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha, device=device)
        else:
            self.alpha = alpha
        self.epsilon = epsilon
    
    '''
    batch中所有样本一起计算loss
    '''
    def forward(self, input, target):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        num_labels = input.size(-1) # 类别数量
        idx = target.view(-1, 1).long() # 行向量target变成列向量idx
        one_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1) # one_hot_key矩阵中的每一行对应相应样本的标签one_hot向量,利用scatter_方法将样本的标签类别标记为1,其余位置为0
        one_hot_key[:, 0] = 0  # ignore 0 index. 此行需要视具体情况决定是否保留,如果标签中存在类别0(而不是直接从类别1开始),此行应当注释、不使用
        logits = torch.softmax(input, dim=-1)
        loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() # 计算FocalLoss
        loss = loss.sum(1)
        return loss.mean()

# 固定随机数种子,方便复现
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    # 设置随机数种子
    setup_seed(20) 
    input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]
    target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]
    output = loss(input, target)
    # print(output)
    output.backward()

方式二

一个batch中逐个样本计算关于FocalLoss的损失值,将它们求平均,返回一个batch内所有样本的FocalLoss的平均值:

import torch
from torch import nn
import random
class FocalLoss(nn.Module):
    """
    参考 https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha, device=device)
        else:
            self.alpha = alpha
        self.epsilon = epsilon
	
    '''
    逐个样本计算loss
    '''    
	def forward(self, input, target):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        num_labels = input.size(-1) # 类别数量
        loss = []
        for i, sample in enumerate(input):
            one_hot_key = torch.zeros(1, num_labels, dtype=torch.float32, device=input.device)
            one_hot_key.scatter_(1, target[i].view(1, -1), 1)

            logits = torch.softmax(sample, dim=-1)
            loss_this_sample = - self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss_this_sample = loss_this_sample.sum(1)
            if i == 0:
                loss = loss_this_sample
            else:
                loss = torch.cat((loss, loss_this_sample))

        return loss.mean()

# 固定随机数种子,方便复现
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    # 设置随机数种子
    setup_seed(20) 
    input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]
    target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]
    output = loss(input, target)
    # print(output)
    output.backward()
### 实现和使用 Focal Loss Focal Loss 是一种用于解决类别不平衡问题的损失函数,在目标检测和其他分类任务中表现出色。该损失函数通过引入两个参数——`alpha` 和 `gamma` 来调整不同类别的权重以及减少简单样本对总损失的影响。 #### 定义 Focal Loss 函数 为了在 PyTorch 中定义 Focal Loss,可以创建一个新的 Python 类继承自 `_Loss` 或者直接编写一个计算损失的方法: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss if self.reduction == 'sum': return torch.sum(F_loss) elif self.reduction == 'mean': return torch.mean(F_loss) else: return F_loss ``` 此代码片段展示了如何构建一个基于二元交叉熵(Binary Cross Entropy, BCE)的 Focal Loss 计算器[^1]。 对于多标签或多分类的情况,则可能需要稍微修改上述逻辑来适应具体的场景需求。例如,在处理单热编码(one-hot encoded)的目标向量时,应该先将其转换成概率分布再应用 focal loss 公式[^2]。 #### 使用 Focal Loss 进行训练 当已经实现Focal Loss 后,就可以像其他标准损失一样应用于模型训练过程中: ```python criterion = FocalLoss(alpha=0.25, gamma=2) for data in dataloader: images, labels = data['image'], data['label'] outputs = model(images) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ``` 这段伪代码说明了怎样实例化并调用之前定义好的 Focal Loss 对象来进行反向传播更新网络权值的操作[^3]。 需要注意的是,实际操作中可能会遇到维度不匹配等问题;这时可以根据具体情况进行适当的数据预处理或张量形状变换以确保输入输出的一致性[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值