目录
https://www.cnblogs.com/Fish0403/p/17073047.html
卷积神经网络系列之softmax,softmax loss和cross entropy的讲解_softmax层_AI之路的博客-CSDN博客
发现对一些概念很模糊,查阅资料再读
1. softmax
softmax用于分类问题
见下图,假设是T分类问题: backbone------logits------prob
2. 交叉熵损失
- 单分类
二分类或者多分类,都是如下公式
pred: [0.9, 0.1, 0] target:[1, 0, 0]
loss = -1*log(0.9) - 0×log(0.1) - 0×log(0) ≈ 0.1
- 多分类
多标签分类任务,即一个样本可以有多个标签,比如一张图片中同时含有猫和狗。
求每一类的损失,再相加
pred: [0.95, 0.73, 0.05] target: [1, 1, 0]
loss狗=-1×log(0.95)-(1-1)×log(1-0.95)≈0.05
loss猫=-1×log(0.73)-(1-1)×log(1-0.73)≈0.31
loss猪=-0×log(0.05)-(1-0)×log(1-0.05)≈0.05
loss总=loss狗+loss猫+loss猪=0.05+0.31+0.05=0.41
3. 代码实现
# 一般而言两者shape不同,inputs.shape(B,num_cla) target.shape(B,)
#CLASS----------类
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,
reduce=None, reduction='mean', label_smoothing=0.0)
使用:
loss = torch.nn.CrossEntropyLoss()
lo = loss(input, target)
#FUNCTION-------函数
torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None,
ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)
使用:
lo = torch.nn.functional.cross_entropy(input, target)
# 损失函数中的weight参数用于调节不同类别样本占比差异很大的现象
# 对于训练图像数量较少的类,给它更多的权重,在预测这些类时出错,会受到更多的惩罚。
# 对于具有大量图像的类,可以赋予它较小的权重。
softmax + log + NLLloss = crossEntropyLoss
- softmax公式如下:
保证概率的非负性,再将转换后的结果归一化处理,使得各预测结果的概率之和等于1。
输入x中存在特别大的xi时,exp(xi)会变得很大,导致出现上溢的情况。当输入x中每个元素都为特别小的负数时,分母会变得很小,超出精度范围时向下取0,导致下溢
-
log_softmax
log_softmax是指在softmax函数的基础上再进行一次log运算,解决上溢、下溢
- NLLloss
将input中与target对应的那个值拿出来,加个负号,再求均值
不用对label进行one_hot编码(换言之,input比target高一维)
举例说明:
import torch
import torch.nn.functional as F
# 1D
input = torch.Tensor([[2, 3, 1], [3, 7, 9]])
target = torch.tensor([1, 2])
loss = F.nll_loss(input, target)
#loss: tensor(-6.) -3-9 = -12 -12/2 = -6
# 2D
input = torch.Tensor([[[2, 3],
[1, 5]],
[[3, 7],
[1, 9]]])
target = torch.tensor([[1, 1],
[0, 0]])
loss = F.nll_loss(input, target)
# tensor(-4.) -3-5
上述代码1D解释:
input.shape = (B,num_cla) = (2,3) target.shape = (B, )=(2, )两个样本的类别索引是1和2
针对第一个样本损失:标签是类别1,所以损失 = -3;
针对第二个样本损失:标签是类别2,所以损失 = -9
求两样本平均(-3-9)/2 = -6