pytorch argmax_Focal loss tf+pytorch实现(带ohem和label smoothing)

Focal loss二分类和多分类一定要分开写,揉在一起会很麻烦。

Tensorflow 实现:

import 

Pytorch 实现:

multi class

import torch

# Pytorch
class Focal_loss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=0, OHEM_percent=0.6, smooth_eps=0, class_num=2, size_average=True):
        super(Focal_loss, self). __init__()
        self.gamma = gamma
        self.alpha = alpha
        self.OHEM_percent = OHEM_percent
        self.smooth_eps = smooth_eps
        self.class_num = class_num
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, logits, label):
        # logits:[b,c,h,w] label:[b,c,h,w]
        pred = logits.softmax(dim=1)
        if pred.dim() > 2:
            pred = pred.view(pred.size(0),pred.size(1),-1)   # b,c,h,w => b,c,h*w
            pred = pred.transpose(1,2)                       # b,c,h*w => b,h*w,c
            pred = pred.contiguous().view(-1,pred.size(2))   # b,h*w,c => b*h*w,c
            label = label.argmax(dim=1)
            label = label.view(-1,1) # b*h*w,1

        if self.alpha:
            self.alpha = self.alpha.type_as(pred.data)
            alpha_t = self.alpha.gather(0, label.view(-1)) # b*h*w
            
        pt = pred.gather(1, label).view(-1) # b*h*w
        diff = (1-pt) ** self.gamma

        FL = -1 * alpha_t * diff * pt.log()
        OHEM = FL.topk(k=int(self.OHEM_percent * FL.size(0)), dim=0)
        if self.smooth_eps > 0:
            K = 16
            lce = -1 * torch.sum(pred.log(), dim=1) / K
            loss = (1-self.eps) * FL + self.eps * lce
        
        if size_average: return loss.mean() # or OHEM.mean()
        else: return loss.sum() # + OHEM.sum()

二分类

import torch

# 二分类
class Focal_loss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=0, size_average=True):
        super(Focal_loss, self). __init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, logits, label):
        # logits:[b,h,w] label:[b,h,w]
        pred = logits.sigmoid()
        pred = pred.view(-1) # b*h*w
        label = label.view(-1)

        if self.alpha:
            self.alpha = self.alpha.type_as(pred.data)
            alpha_t = self.alpha * label + (1 - self.alpha) * (1 - label) # b*h*w
            
        pt = pred * label + (1 - pred) * (1-label)
        diff = (1-pt) ** self.gamma

        FL = -1 * alpha_t * diff * pt.log()
        
        if size_average: return FL.mean()
        else: return FL.sum()

参数应该不用多讲,看名字就知道。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值