02 图像分割损失函数

熵:代表一个系统的混乱程度

在图像分割中,rgb图像标签一般是单通道的,三维张量,shape=[batchsize,hight,width],假设该数据集共有nclasses类,网络预测输出为四维张量,shape=[batchsize,nclasses,hight,width]

交叉熵CrossEntropyLoss()

参考pytorch语义分割中CrossEntropyLoss()损失函数的理解与分析_nn.crossentropyloss 语义分割-CSDN博客

 官方源码: CrossEntropyLoss — PyTorch 2.2 documentation

pytorch语义分割中CrossEntropy、FocalLoss和DiceLoss三类损失函数的理解与分析_dice loss震荡-CSDN博客

上述链接交叉熵损失图文+代码解释非常全面

深入理解这段话:

  1. 假设标签上第一个像素点上的标签为7,计算交叉熵时,就会只对应第七张预测特征图上的第一个像素点。
  2. 交叉熵的计算只是把标签作为索引值,去索引对应的特征图,然后找到特征图上相应位置的像素值,作为网络预测的概率值。
  3. 这也就是为什么当计算交叉熵时,填写的类别数与数据集的类别数不相等时会爆出标签越界的错误。比如说在VOC数据集中,需区分20类别,但是输入是21类,加上了背景信息,以防报错。(这段话有待实验验证)

最详细的语义分割---07交叉熵到底在干什么?_语义分割交叉熵-CSDN博客

公式理解:

  1. q(x)就是模型预测第n张特征图上的像素点属于该类别的概率,概率是小于0到1的,所以这里需要有softmax在通道数求概率(至于为什么在通道维度,因为通道代表着不同的类别),通道维度=1
  2. p(x)是标签,那么这个标签是真实标签图像上面的值吗,答案不是。这个标签值应该是所有类别的第几类,比如上面,标签值=7。因为前面讲了,我们是那这个标签去索引对应位置的像素点,然后计算交叉熵,既然索引到了就说明这个标签是真实的,有效的。

源码:

torch.nn.CrossEntropyLoss(weight=None, size_average=None, 
ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
  1. weight (Tensoroptional):  a Tensor of size C and floating point dtype.C是类别个数,为每一个类手动设置权重
  2. size_average (booloptional):默认下,True,代表一个batch数据的平均损失;False是指batch数据损失和
  3. ignore_index (intoptional):指定某一类,该类将被忽略且不影响输入梯度。
  4. reduce (booloptional):暂时不管
  5. reduction (stroptional):mean是指输出加权平均值,sum:输出之和
  6. label_smoothing (floatoptional) – A float in [0.0, 1.0] 指定计算损失时的平滑量
import torch.nn as nn
import torch

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()


# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

'''
观察target 输入类型
1.类别 index 索引
2.由softmax计算的类可能性
'''
import torch.nn as nn
loss_func = nn.CrossEntropyLoss()
pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float)
tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float)
print(loss_func(pre, tgt))
# tensor(1.1087)

loss_func_none = nn.CrossEntropyLoss(reduction="none")
loss_func_mean = nn.CrossEntropyLoss(reduction="mean")
loss_func_sum = nn.CrossEntropyLoss(reduction="sum")
pre = torch.tensor([[0.8, 0.5, 0.2, 0.5],
                    [0.2, 0.9, 0.3, 0.2],
                    [0.4, 0.3, 0.7, 0.1],
                    [0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt = torch.tensor([[1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]], dtype=torch.float)
print(loss_func_none(pre, tgt))
print(loss_func_mean(pre, tgt))
print(loss_func_sum(pre, tgt))

'''
tensor([1.1087, 0.9329, 1.0852, 0.9991])
tensor(1.0315)
tensor(4.1259)

'''

torch.nn.CrossEntropyLoss() 参数、计算过程以及及输入Tensor形状 - 知乎 (zhihu.com)

带权重的交叉熵

参考

加权交叉熵损失函数-CSDN博客

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数-CSDN博客

交叉熵ignore_index

loss = nn.CrossEntropyLoss(ignore_index=9) 

忽视标签9,对应像素值不参与损失计算和梯度计算

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值