熵:代表一个系统的混乱程度
在图像分割中,rgb图像标签一般是单通道的,三维张量,shape=[batchsize,hight,width],假设该数据集共有nclasses类,网络预测输出为四维张量,shape=[batchsize,nclasses,hight,width]
交叉熵CrossEntropyLoss()
参考pytorch语义分割中CrossEntropyLoss()损失函数的理解与分析_nn.crossentropyloss 语义分割-CSDN博客
官方源码: CrossEntropyLoss — PyTorch 2.2 documentation
pytorch语义分割中CrossEntropy、FocalLoss和DiceLoss三类损失函数的理解与分析_dice loss震荡-CSDN博客
上述链接交叉熵损失图文+代码解释非常全面
深入理解这段话:
- 假设标签上第一个像素点上的标签为7,计算交叉熵时,就会只对应第七张预测特征图上的第一个像素点。
- 交叉熵的计算只是把标签作为索引值,去索引对应的特征图,然后找到特征图上相应位置的像素值,作为网络预测的概率值。
- 这也就是为什么当计算交叉熵时,填写的类别数与数据集的类别数不相等时会爆出标签越界的错误。比如说在VOC数据集中,需区分20类别,但是输入是21类,加上了背景信息,以防报错。(这段话有待实验验证)
最详细的语义分割---07交叉熵到底在干什么?_语义分割交叉熵-CSDN博客
公式理解:
- q(x)就是模型预测第n张特征图上的像素点属于该类别的概率,概率是小于0到1的,所以这里需要有softmax在通道数求概率(至于为什么在通道维度,因为通道代表着不同的类别),通道维度=1
- p(x)是标签,那么这个标签是真实标签图像上面的值吗,答案不是。这个标签值应该是所有类别的第几类,比如上面,标签值=7。因为前面讲了,我们是那这个标签去索引对应位置的像素点,然后计算交叉熵,既然索引到了就说明这个标签是真实的,有效的。
源码:
torch.nn.CrossEntropyLoss(weight=None, size_average=None,
ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
- weight (Tensor, optional): a Tensor of size C and floating point dtype.C是类别个数,为每一个类手动设置权重
- size_average (bool, optional):默认下,True,代表一个batch数据的平均损失;False是指batch数据损失和
- ignore_index (int, optional):指定某一类,该类将被忽略且不影响输入梯度。
- reduce (bool, optional):暂时不管
- reduction (str, optional):mean是指输出加权平均值,sum:输出之和
- label_smoothing (float, optional) – A float in [0.0, 1.0] 指定计算损失时的平滑量
import torch.nn as nn
import torch
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
'''
观察target 输入类型
1.类别 index 索引
2.由softmax计算的类可能性
'''
import torch.nn as nn
loss_func = nn.CrossEntropyLoss()
pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float)
tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float)
print(loss_func(pre, tgt))
# tensor(1.1087)
loss_func_none = nn.CrossEntropyLoss(reduction="none")
loss_func_mean = nn.CrossEntropyLoss(reduction="mean")
loss_func_sum = nn.CrossEntropyLoss(reduction="sum")
pre = torch.tensor([[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt = torch.tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]], dtype=torch.float)
print(loss_func_none(pre, tgt))
print(loss_func_mean(pre, tgt))
print(loss_func_sum(pre, tgt))
'''
tensor([1.1087, 0.9329, 1.0852, 0.9991])
tensor(1.0315)
tensor(4.1259)
'''
torch.nn.CrossEntropyLoss() 参数、计算过程以及及输入Tensor形状 - 知乎 (zhihu.com)
带权重的交叉熵
参考
关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数-CSDN博客
交叉熵ignore_index
loss = nn.CrossEntropyLoss(ignore_index=9)
忽视标签9,对应像素值不参与损失计算和梯度计算