学习目标:
· 搞清楚损失函数的运作流程(基于pytorch)
· 手写个Focalloss
· 搞清楚底层nn.Module原理
损失函数类:
1、 搭建损失函数类,继承nn.Module
2、 重写forward方法, forward方法输入即为网络预测值与真实标签
3、 返回即损失的值
4、 里面的运输都是张量的运算并且都在cuda环境下,注意有时候申请变量会是在cpu环境而引发错误,forward过程即损失前馈过程,对于如何方向传播,哪些变量是会自动求导的,需要后续看看nn.Module源码。
基于one-hot的Focalloss(3d):
# 只有正样本和负样本
class Focal(nn.Module):
def __init__(self, gamma=2, alpha=0.25, size_average=True):
super(Focal, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.size_average = size_average
def forward(self, input, target, weight=None):
assert input.size()==target.size(),'class FocalLoss:input与target不一样大小'
# 将预测值与目标真值都转换为(N*c*z*x*y,1)
pred = input.contiguous().view(-1,1)
target = target.contiguous().view(-1,1)
# 将预测样本的正副概率都计算出来,(pred,2)
pred = torch.cat((1-pred,pred),dim=1).cuda()
#根据target生成mask,将真值正负两种情况对应前面的pred的提取出来
mask = torch.zeros(pred.size()).cuda()
mask.scatter_(1,target.view(-1,1).long(), 1)
# 利用 mask 将所需概率值挑选出来
probs = (pred * mask).sum(dim=1).view(-1, 1)
probs = probs.clamp(min=0.0001, max=1.0)
# 计算概率的 log 值
log_p = probs.log()
# 代入公式计算
alpha = torch.ones(pred.size()).cuda()
alpha[:, 0] = alpha[:, 0] * (1 - self.alpha)
alpha[:, 1] = alpha[:, 1] * self.alpha
alpha = (alpha * mask).sum(dim=1).view(-1, 1)
batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.25, size_average=True):
super(FocalLoss, self).__init__()
self.focal = Focal(gamma,alpha,size_average)
def forward(self, input, target, weight=None):
# 自己的网络是onehot编码的
soft = nn.Softmax(dim=1)
input = soft(input)
C = input.size(1)
# 循环计算每个label的loss
loss_all = 0
for i in range(C):
#重新定义每层的目标标签为1
mask = torch.zeros(target.size()).cuda()
mask[target == i] = 1
pred = input[:,i,:,:,:]
target1 = mask[:,0,:,:,:]
loss = self.focal(pred, target1)
loss_all += loss
return loss_all/C
理解nn.Module:
损失函数的本质也就是“对输入进行函数运算,得到一个输出”
继承nn.Module相当于自己实现损失函数的前馈过程,而求导以及反馈过程是nn.Module中帮忙实现的。也就是pytorch将“模块、层、激活函数、损失函数”这些概念统一到了一起。
但是nn.Module是如何或者在哪将优化器,求导,反向传播联系到一起的呢。。。。。