多分类focal loss代码实现,基于正常loss函数获取多分类概率后,用文中几行代码公式,计算得到新的loss


from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=5, size_average=True):

        super(focal_loss, self).__init__()
        self.size_average = size_average
        if isinstance(alpha, (float, int)):    #仅仅设置第一类别的权重
            print("33afsd33")
            self.alpha = torch.zeros(num_classes)
            print(self.alpha, alpha)
            self.alpha[0] += alpha
            print(self.alpha[0])
            self.alpha[1:] += (1 - alpha)
            print(self.alpha[1:])
        if isinstance(alpha, list):  #全部权重自己设置
            print("333d3")
            self.alpha = torch.Tensor(alpha)
        self.gamma = gamma


    def forward(self, inputs, targets):
        alpha = self.alpha
        N = inputs.size(0)
        C = inputs.size(1)
        # 下面这些只是为了获取四个样本的概率probs
        P = F.softmax(inputs,dim=1)
        # ---------one hot start--------------#
        class_mask = inputs.data.new(N, C).fill_(0)  # 生成和input一样shape的tensor
        class_mask = class_mask.requires_grad_()  # 需要更新, 所以加入梯度计算
        ids = targets.view(-1, 1)  # 取得目标的索引
        class_mask.data.scatter_(1, ids.data, 1.)  # 利用scatter将索引丢给mask
        # ---------one hot end-------------------#
        probs = (P * class_mask).sum(1).view(-1, 1)
        print('留下targets的概率(1的部分),0的部分消除\n', probs)
        # 将softmax * one_hot 格式,0的部分被消除 留下1的概率, shape = (5, 1), 5就是每个target的概率
        #



        # 上面那些不需要管,重点看下面的focal loss公式;其实魔改自己多分类的,就是这里加上
        log_p = probs.log()
        # 取得对数
        print("1 - probs",1 - probs)
        loss = torch.pow((1 - probs), self.gamma) * log_p
        print("loss", loss)
        batch_loss = -alpha *loss.t()  # 對應下面公式
        print('每一个batch的loss\n', batch_loss)
        # batch_loss就是取每一个batch的loss值

        # 最终将每一个batch的loss加总后平均
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        print('loss值为\n', loss)
        return loss

#多分类 五类数据,第一类少样本数据,a= 0.25,其他都是0.75;

torch.manual_seed(50) #随机种子确保每次input tensor值是一样的
input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
# print('input值为\n', input)
targets = torch.randint(5, (5, ))
print('targets值为\n', targets)

criterion = focal_loss()
loss = criterion(input, targets)
loss.backward()
# 针对多分类任务的 CELoss 和 Focal Loss
import torch
import torch.nn as nn
import torch.nn.functional as F

class CELoss(nn.Module):
    def __init__(self, class_num, alpha=None, use_alpha=False, size_average=True):
        super(CELoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        if use_alpha:
            self.alpha = torch.tensor(alpha).cuda()

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        prob = self.softmax(pred.view(-1,self.class_num))
        prob = prob.clamp(min=0.0001,max=1.0)

        target_ = torch.zeros(target.size(0),self.class_num).cuda()
        target_.scatter_(1, target.view(-1, 1).long(), 1.)

        if self.use_alpha:
            batch_loss = - self.alpha.double() * prob.log().double() * target_.double()
        else:
            batch_loss = - prob.log().double() * target_.double()

        batch_loss = batch_loss.sum(dim=1)

        # print(prob[0],target[0],target_[0],batch_loss[0])
        # print('--')

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, use_alpha=False, size_average=True):
        super(FocalLoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        self.gamma = gamma
        if use_alpha:
            self.alpha = torch.tensor(alpha).cuda()

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        prob = self.softmax(pred.view(-1,self.class_num))
        prob = prob.clamp(min=0.0001,max=1.0)

        target_ = torch.zeros(target.size(0),self.class_num).cuda()
        target_.scatter_(1, target.view(-1, 1).long(), 1.)

        if self.use_alpha:
            batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
        else:
            batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()

        batch_loss = batch_loss.sum(dim=1)

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss


torch.manual_seed(50) #随机种子确保每次input tensor值是一样的
input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
# print('input值为\n', input)
targets = torch.randint(5, (5, ))
print('targets值为\n', targets)

criterion = FocalLoss()
loss = criterion(input, targets)
loss.backward()


FL中伽马等于0,就是CE

 

  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值