For focal loss, the binary and multi-class cases really should be implemented separately; cramming them into a single class gets messy fast.
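For reference, the focal loss from Lin et al.'s RetinaNet paper is defined as

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)

where p_t is the predicted probability of the true class, α_t is the class-balancing weight, and γ is the focusing parameter; with γ = 0 it reduces to ordinary weighted cross entropy.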
PyTorch implementation:
Multi-class:
import torch

# PyTorch multi-class focal loss, with optional OHEM and label smoothing
class Focal_loss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=0, OHEM_percent=0.6, smooth_eps=0, class_num=2, size_average=True):
        super(Focal_loss, self).__init__()
        self.gamma = gamma                # focusing parameter; gamma = 0 reduces to cross entropy
        self.alpha = alpha                # per-class weights, e.g. a list of length class_num
        self.OHEM_percent = OHEM_percent  # fraction of hardest samples kept by OHEM
        self.smooth_eps = smooth_eps      # label smoothing coefficient
        self.class_num = class_num
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, logits, label):
        # logits: [b,c,h,w]  label: one-hot, [b,c,h,w]
        pred = logits.softmax(dim=1)
        if pred.dim() > 2:
            pred = pred.view(pred.size(0), pred.size(1), -1)  # b,c,h,w => b,c,h*w
            pred = pred.transpose(1, 2)                       # b,c,h*w => b,h*w,c
            pred = pred.contiguous().view(-1, pred.size(2))   # b,h*w,c => b*h*w,c
        label = label.argmax(dim=1).view(-1, 1)               # one-hot => class indices, b*h*w,1

        if self.alpha is not None:
            self.alpha = self.alpha.type_as(pred)
            alpha_t = self.alpha.gather(0, label.view(-1))    # b*h*w
        else:
            alpha_t = 1.0

        pt = pred.gather(1, label).view(-1)                   # b*h*w, prob of the true class
        pt = pt.clamp(min=1e-7)                               # avoid log(0)
        diff = (1 - pt) ** self.gamma
        FL = -1 * alpha_t * diff * pt.log()
        # OHEM: keep only the hardest OHEM_percent of the samples
        OHEM, _ = FL.topk(k=int(self.OHEM_percent * FL.size(0)), dim=0)

        if self.smooth_eps > 0:
            # cross entropy against a uniform distribution over the classes
            lce = -1 * torch.sum(pred.clamp(min=1e-7).log(), dim=1) / self.class_num
            loss = (1 - self.smooth_eps) * FL + self.smooth_eps * lce
        else:
            loss = FL

        if self.size_average:
            return loss.mean()  # or OHEM.mean()
        else:
            return loss.sum()   # or OHEM.sum()
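A minimal smoke test for the multi-class version (the shapes and hyperparameter values here are made up for illustration; the labels are one-hot, as the forward pass expects):

logits = torch.randn(4, 3, 8, 8)                              # b=4, c=3, 8x8 maps
target = torch.randint(0, 3, (4, 8, 8))                       # class indices
one_hot = torch.nn.functional.one_hot(target, num_classes=3)  # [b,h,w,c]
one_hot = one_hot.permute(0, 3, 1, 2).float()                 # [b,c,h,w]
criterion = Focal_loss(alpha=[0.5, 1.0, 1.0], gamma=2, class_num=3)
print(criterion(logits, one_hot).item())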
Binary classification:
import torch

# PyTorch binary focal loss
class Focal_loss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=0, size_average=True):
        super(Focal_loss, self).__init__()
        self.gamma = gamma    # focusing parameter
        self.alpha = alpha    # weight of the positive class, e.g. 0.25
        self.size_average = size_average

    def forward(self, logits, label):
        # logits: [b,h,w]  label: [b,h,w], values 0/1
        pred = logits.sigmoid()
        pred = pred.view(-1)            # b*h*w
        label = label.view(-1).float()

        if self.alpha is not None:
            alpha_t = self.alpha * label + (1 - self.alpha) * (1 - label)  # b*h*w
        else:
            alpha_t = 1.0

        pt = pred * label + (1 - pred) * (1 - label)  # prob of the true class
        pt = pt.clamp(min=1e-7)                       # avoid log(0)
        diff = (1 - pt) ** self.gamma
        FL = -1 * alpha_t * diff * pt.log()

        if self.size_average:
            return FL.mean()
        else:
            return FL.sum()
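And for the binary version (again, the shapes and the alpha/gamma values are just for illustration):

logits = torch.randn(4, 8, 8)                   # b=4, 8x8 maps
label = torch.randint(0, 2, (4, 8, 8)).float()  # 0/1 ground truth
criterion = Focal_loss(alpha=0.25, gamma=2)
print(criterion(logits, label).item())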
The parameters should be clear from their names: gamma is the focusing parameter, alpha weights the classes, OHEM_percent is the fraction of hardest samples kept, and smooth_eps is the label smoothing coefficient.