Focal Loss 分类问题 pytorch实现代码(续3)

ps:虽然无法用NLLLoss函数来实现.但好歹最后实现了自己的想法.现在再来测试下最后和最开始的Focal Loss如下:

import torch
import torch.nn as nn
 
#二分类
class FocalLoss(nn.Module):
 
    def __init__(self, gamma=2,alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha=alpha
    def forward(self, input, target):
        # input:size is M*2. M is the batch number
        # target:size is M.
        pt=torch.softmax(input,dim=1)
        p=pt[:,1]
        loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\
               (1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))
        return loss.mean()
import torch
import torch.nn as nn


class FocalLoss2(nn.Module):

    def __init__(self, gamma=0, alpha=1):
        super(FocalLoss2, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        loss = self.alpha * loss
        return loss.mean()

用代码去测试实例:

import torch
from loss import FocalLoss
from loss2 import FocalLoss2


input=torch.Tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557,  1.4724]])
target= torch.Tensor([1,0,1,1])

# input=torch.Tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496]])
# target= torch.Tensor([1,0])

# input=torch.Tensor([[ 0.0543,  0.5641]])
# target= torch.Tensor([1])



print(torch.softmax(input,dim=1))

criterion = FocalLoss(gamma=2,alpha=0.25)
criterion1 = FocalLoss2(gamma=2,alpha=0.25)
criterion2 = torch.nn.CrossEntropyLoss()


res = criterion(input, target)
print(res)
res1 = criterion1(input, target.long())
print(res1)
res2 = criterion2(input, target.long())
print(res2)



tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0049)
tensor(0.2966)

改变target为:target= torch.Tensor([0,1,0,0])

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.5403)
tensor(0.2289)
tensor(1.5092)

发现Focal Loss和Focal Loss2趋势相类似,且数量级总体相差不大.其实Focal Loss大概可以用Focal Loss2表示,敢写出来放到github上应该没什么问题.这是我的理解,希望对你的理解有帮助.

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值