ps:感谢Code_Mart的解答,肯定了思路,不过他也不确定是否可以在pytorch中那么写.事情这样模棱两可让我很烦躁决定深究一下.看到博客https://blog.csdn.net/qq_22210253/article/details/85229988对CrossEntropyLoss的实测决定二分类的上再实测一下理解.
在图片二分类时,输入m张图片,输出一个m*2的Tensor(跟我的模型输出一样)。实际输入3张图片,分二类,最后的输出是一个3*2的Tensor,举例如下:
>>> input=torch.randn(3,2)
>>> input
tensor([[ 0.0082, 1.2996],
[ 0.1396, 0.4143],
[-1.6190, 1.1246]])
假设第一列是neg类,第二列是pos类.
然后对每一行使用Softmax,这样可以得到每张图片的概率分布.
>>> sm=torch.nn.Softmax(dim=1)
>>> sm(input)
tensor([[0.2156, 0.7844],
[0.4318, 0.5682],
[0.0604, 0.9396]])
然后对Softmax的结果取自然对数:
>>> torch.log(sm(input))
tensor([[-1.5343, -0.2429],
[-0.8399, -0.5652],
[-2.8060, -0.0623]])
Softmax后的数值都在0~1之间,所以ln之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
假设我们现在Target是[1,0,1](第一张图片是pos,第二张是neg,第三张是pos)。第一行取第1个元素,第二行取第0个,第三行取第1个,去掉负号,结果是:[0.2429,0.8399,0.0623]。再求个均值,结果是:
>>> (0.2429+0.8399+0.0623)/3
0.3817
先用NLLLoss函数实验一下:
>>> loss=torch.nn.NLLLoss()
>>> target = torch.tensor([1,0,1])
>>> loss(torch.log(sm(input)),target)
tensor(0.3817)
再用CrossEntropyLoss实验一下
>>> celoss=torch.nn.CrossEntropyLoss()
>>> celoss(input,target)
tensor(0.3817)
果然如此,真的感谢堆排序宝宝作者给我这么直观的感觉.
现在,再次回过头看看昨天的博客里的代码好像有点问题了.先贴下昨天的
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, gamma=2,alpha=0.25):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.ce = nn.CrossEntropyLoss()
self.alpha=alpha
def forward(self, input, target):
logp = self.ce(input, target)
p = torch.exp(-logp)
loss = self.alpha*(1 - p) ** self.gamma * logp * target + \
(1-self.alpha)*(p) ** self.gamma * logp * (1-target)
return loss.mean()
其中p = torch.exp(-logp)是不是有点问题呢.继续上面的例子:
>>> logp=celoss(input,target)
>>> logp
tensor(0.3817)
>>> torch.exp(-logp)
tensor(0.6827)
感觉根本不能对应回去啊,我理解不了.我就换一下极端一点的数字如下测试
>>> input2=torch.tensor([[-1.8,1.8],[1.3,-1.2],[-1.6,1.5]])
>>> sm(input2)
tensor([[0.0266, 0.9734],
[0.9241, 0.0759],
[0.0431, 0.9569]])
>>> logp=celoss(input2,target)
>>> torch.exp(-logp)
tensor(0.9513)
这样一看,虽然反不回去,但是总体还是能体现概率.0.9513接近(0.9734,0.9241,0.9569)三者均值附近,而0.6827也在(0.7844,0.4318,0.9396)三者均值附近.好像是预测越准确越接近均值.但一个值替代不了矩阵中6个值.