深度学习| Pytorch实现DiseLoss代码

基础

Dice的计算在图像分割中基本不可避免会使用,用来作为评价指标。

Dice可以计算集合的相似程度,取值范围在[0,1],公式如下所示:
D i c e ( X , Y ) = 2 ∗ ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice(X,Y)=\frac{2*|X\cap Y|}{|X|+|Y|} Dice(X,Y)=X+Y2XY

Dice Loss表达式:
1 − D i c e ( X , Y ) = 1 − 2 ∗ ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ 1-Dice(X,Y)=1-\frac{2*|X\cap Y|}{|X|+|Y|} 1Dice(X,Y)=1X+Y2XY

手推

对于图像分割来说,Dice Loss的计算中,X看作是Label(标签)像素点的集合,Y看作是Prediction(预测)像素点的集合。

举个具体的例子:
X是ground truth中0的背景,1是前景;Y是预测的mask。
X = [ 0 0 0 1 1 1 1 1 1 ] X=\begin{bmatrix} 0&0&0\\ 1&1&1\\ 1&1&1\\ \end{bmatrix} X= 011011011
Y = [ 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 ] Y=\begin{bmatrix} 0.01&0.02&0.03\\ 0.04&0.05&0.06\\ 0.07&0.08&0.09\\ \end{bmatrix} Y= 0.010.040.070.020.050.080.030.060.09
计算 X ⋂ Y X\bigcap Y XY:
X ⋂ Y = [ 0 0 0 1 1 1 1 1 1 ] ∗ [ 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 ] = [ 0 0 0 0.04 0.05 0.06 0.07 0.08 0.09 ] X\bigcap Y=\begin{bmatrix} 0&0&0\\ 1&1&1\\ 1&1&1\\ \end{bmatrix}* \begin{bmatrix} 0.01&0.02&0.03\\ 0.04&0.05&0.06\\ 0.07&0.08&0.09\\ \end{bmatrix}=\begin{bmatrix} 0&0&0\\ 0.04&0.05&0.06\\ 0.07&0.08&0.09\\ \end{bmatrix} XY= 011011011 0.010.040.070.020.050.080.030.060.09 = 00.040.0700.050.0800.060.09
计算 ∣ X ∣ 和 ∣ Y ∣ ,所有元素相加就行 |X|和|Y|,所有元素相加就行 XY,所有元素相加就行
∣ X ∣ = 6 |X|=6 X=6
∣ Y ∣ = 0.45 |Y|=0.45 Y=0.45
这么一推发现,在计算 ∣ X ⋂ Y ∣ |X\bigcap Y| XY就会把背景的预测给mask掉,从而能够更考虑前景,在前景占比比较小的时候,可以添加Dice Loss来防止收敛到占比更大的背景上,防止忽略前景分割的准确性。

代码

完整代码

class DiceLoss(_Loss):
    def forward(self, output, target, ignore_index=None):
        """
            output : NxCxHxW Variable
            target :  NxHxW LongTensor
            ignore_index : int index to ignore from loss
        """
        eps = 0.0001

        output = output.exp()# 计算指数
        encoded_target = output.detach() * 0# detach()返回与旧的tensor共享内存,不参加梯度计算
        if ignore_index is not None:
            mask = target == ignore_index# 会输出target形状的bool类型张量,和ignore_index相等的是True,不相等则是False
            target = target.clone()# Pytorch赋值指向都是同一个内存地址,clone可以实现深拷贝
            target[mask] = 0# 会把target中和mask相同位置为True的值替换成O
            encoded_target.scatter_(1, target.unsqueeze(1), 1)# 多类别label转为one-hot类型的label
            mask = mask.unsqueeze(1).expand_as(encoded_target)# 把mask进行升为并扩展到和encoded_target一样的形状,也就是NxCxHxW
            encoded_target[mask] = 0# 由于一张图片每个像素点只能有一个label所以把one-hot类型所有label该点值都赋值为0也没关系
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)# 多类别label转为one-hot类型的label

        intersection = output * encoded_target# 求output和label交集
        numerator = 2 * intersection.sum(0).sum(1).sum(1)# 交集所有元素合的两倍
        denominator = output + encoded_target# 求output和label并集

        if ignore_index is not None:
            denominator[mask] = 0
        denominator = denominator.sum(0).sum(1).sum(1) + eps# 并集所有元素合的两倍,这里得到的是C大小的张量
        loss_per_channel = 1 - (numerator / denominator)# Dice Loss计算

        return loss_per_channel.sum() / output.size(1)# 总Dice Loss/label类数

代码解释补充

scatter函数的理解
scatter(dim, index, src)将src中数据根据index中的索引按照dim的方向进行填充。
scatter不会修改原来的tensor,scatter_则是会修改到原来的tensor。
二维张量:

self[ index[i][j] ] [j] = src[i][j]# if dim == 0
self [i] [ index[i][j] ] = src[i][j]# if dim == 1

三维张量:

self[ index[i][j][k]  ] [j][k] = src[i][j][k]# if dim == 0
self[i] [ index[i][j][k] ] [k] = src[i][j][k]# if dim == 1
self[i][j] [ index[i][j][k] ]  = src[i][j][k]# if dim == 2

回到DiseLoss应用场景中来:用scatter_的作用是什么?将返回多类别label转为one-hot类型的label。
output是NxCxHxW,其中N是图片数量、C是Label种数、H是图片高度、W是图片宽度,是每张图片不同类别在不同像素点上的预测值。
target是NxHxW,是每张图片在不同像素点的类别标签。

# encoded_target是全为0的NxCxHxW的张量
# target是NxHxW大小,每张图片在不同像素点的类别标签
# target.unsqueeze(1)对target进行升维度,变成Nx1xHxW形状的维度,因为scatter要求维度要相同
encoded_target.scatter_(1, target.unsqueeze(1), 1)
# 经过上述步骤后,encoded_target就变成了NxCxHxW,因为target中值是不同类别

sum函数的理解
sum(1):求数组每一行的和,纵向压缩。
sum(0):求数组每一列的和,横向压缩。
回到DiseLoss应用场景中来:denominator.sum(0).sum(1).sum(1),为什么要这么多次sum?四维张量的sum是怎么变化?
写个简单的四维张量,输出就知道了。

# 代码
x=torch.tensor([[[[ 0,  1,  2],[ 3,  4,  5],[ 6,  7, 8],[ 9,  10,  11]],[[ 0,  1,  2],[ 3,  4,  5],[ 6,  7, 8],[ 9,  10,  11]]]])
print(x.size())
print(x.sum(0))
print(x.sum(0).size())
print(x.sum(0).sum(1))
print(x.sum(0).sum(1).size())
print(x.sum(0).sum(1).sum(1))
print(x.sum(0).sum(1).sum(1).size())
# 输出
torch.Size([1, 2, 4, 3])
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],
         
        [[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]])
torch.Size([2, 4, 3])
tensor([[18, 22, 26],
        [18, 22, 26]])
torch.Size([2, 3])
tensor([66, 66])
torch.Size([2])

denominator是NxCxHxW,通过上面的例子可以发现denominator.sum(0).sum(1).sum(1)得到的结果是每个种Label类别预测值的求和,是一个C大小的张量。

杂谈

开始深究深度学习的代码,发现python还挺麻烦的,主要是在读代码方面。Python虽然可以不用申明各个变量的类型,很方便,但很多时候光看代码都很难判断出该变量是个什么情况。虽然Python有很多方便的功能函数用法,使用起来很简洁,但是读代码的时候就要不停的搜,有时候写的很简洁的用法还很难搜出来,例如“mask = target == ignore_index”。

所以写代码尤其是函数的时候,可以注释一下张量的形状、类型。读代码的时候,除了上网搜索来理解,还可以多写个简单案例的代码来print一下,。

  • 28
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值