"""
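predect : prediction input, a probability map that has already been passed through softmax (values in [0, 1])
target : labels given as class indices, e.g. with 12 classes the labels are [0, 1, 2, 3, ..., 11]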
"""
import numpy as np
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Multi-class focal loss computed on per-class probabilities.

    For every sample (pixel) the loss is
    ``-alpha_t * (1 - pt) ** gamma * log(pt)``, where ``pt`` is the
    label-smoothed probability assigned to the target class and ``alpha_t``
    is the weight of that class.

    Args:
        apply_nonlin: Optional callable applied to ``predect`` before the loss
            is computed (for example a softmax when the network outputs
            logits). ``None`` means ``predect`` is used as given.
        alpha: Class weighting.
            ``None`` gives every class weight 1.
            A ``list`` or ``np.ndarray`` of length ``num_class`` supplies one
            weight per class; the weights are normalised to sum to 1.
            A ``float`` gives class ``balance_index`` the weight ``alpha`` and
            every other class ``1 - alpha``.
        gamma: Focusing exponent applied to ``1 - pt``.
        balance_index: Class that receives ``alpha`` when ``alpha`` is a float.
        smooth: Label-smoothing amount; it is also added to ``pt`` so that
            ``log(pt)`` stays finite. ``0`` disables the one-hot clamping.
        size_average: If true return the mean over all samples, otherwise
            the sum.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

    def _class_weights(self, num_class, device):
        """Build the ``(num_class, 1)`` class-weight tensor on ``device``.

        Raises:
            TypeError: if ``self.alpha`` is not ``None``, a list, an
                ``np.ndarray`` or a float.
        """
        alpha = self.alpha
        # Identity check: ``ndarray == None`` is an elementwise comparison and
        # cannot be used as an ``if`` condition.
        if alpha is None:
            return torch.ones(num_class, 1, device=device)
        if isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            weights = torch.as_tensor(alpha, dtype=torch.float32, device=device)
            weights = weights.reshape(num_class, 1)
            return weights / weights.sum()
        if isinstance(alpha, float):
            weights = torch.ones(num_class, 1, device=device) * (1 - alpha)
            weights[self.balance_index] = alpha
            return weights
        raise TypeError("not support alpha type")

    def forward(self, predect, target):
        """Compute the focal loss.

        Args:
            predect: Tensor whose dimension 1 is the class dimension, i.e.
                ``(N, C)`` or ``(N, C, d1, d2, ...)``. Expected to hold
                probabilities once ``apply_nonlin`` (if any) has been applied.
            target: Integer class indices, one per sample/pixel of
                ``predect``. Any shape is accepted because it is flattened.

        Returns:
            A scalar tensor: the mean loss if ``size_average`` is true, the
            summed loss otherwise.
        """
        if self.apply_nonlin is not None:
            predect = self.apply_nonlin(predect)
        num_class = predect.shape[1]
        device = predect.device

        if predect.dim() > 2:
            # (N, C, d1, d2, ...) -> (N * d1 * d2 * ..., C)
            predect = predect.view(predect.size(0), predect.size(1), -1)
            predect = predect.permute(0, 2, 1).contiguous()
            predect = predect.view(-1, predect.size(-1))

        # Flatten the labels to one column. A plain reshape also handles 1-D
        # targets, which a squeeze on dimension 1 would reject.
        idx = target.reshape(-1, 1).long().to(device)

        class_weights = self._class_weights(num_class, device)

        # One-hot encode the labels on the same device as the predictions.
        one_hot_key = torch.zeros(idx.size(0), num_class, dtype=torch.float32, device=device)
        one_hot_key.scatter_(1, idx, 1)

        # Label smoothing: zeros become smooth / (num_class - 1) and ones
        # become 1 - smooth.
        if self.smooth != 0:
            one_hot_key = torch.clamp(one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)

        # Smoothed probability of the target class; the added ``smooth`` keeps
        # the logarithm finite.
        pt = (one_hot_key * predect).sum(dim=1) + self.smooth

        # Weight of each sample's target class, shape (num_samples,).
        sample_weights = class_weights.view(-1)[idx.view(-1)]

        loss = -1 * sample_weights * torch.pow((1 - pt), self.gamma) * pt.log()
        if self.size_average:  # divide by the number of samples (pixels)
            return loss.mean()
        return loss.sum()
if __name__ == '__main__':
    # Small smoke test on a hand-written probability map of shape
    # (1, 2, 4, 4): channel 0 holds the class-0 probabilities and channel 1
    # the class-1 probabilities, every row being identical.
    class0_row = [0.2, 0.2, 0.3, 0.3]
    class1_row = [0.8, 0.8, 0.7, 0.7]
    probs = torch.tensor([[[class0_row] * 4, [class1_row] * 4]])

    # Labels of shape (4, 4): left half is class 1, right half is class 0.
    labels = torch.tensor([[1, 1, 0, 0]] * 4)

    # Per-class weights: class 0 is only predicted with probability 0.3
    # (too low), so it gets the large weight 0.9; class 1 is predicted with
    # probability 0.8 (high), so it gets the small weight 0.1.
    criterion = FocalLoss(
        apply_nonlin=None,
        alpha=[0.9, 0.1],
        gamma=2,
        balance_index=0,
        smooth=1e-5,
        size_average=True,
    )
    print(criterion(probs, labels))
# focal loss
# (web page residue) Latest recommended article published on 2024-05-27 14:44:03