pytorch-semseg代码解读分割loss.py

其中关于contiguous()函数介绍在:

PyTorch中的contiguous

其中关于transpose()函数在stack overflow上有个问题讲的不错:

Tranpose讲解

关于input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)的详细解释:

比如假如原来input 是 2*3*10*10,也就是两张图,每张图三个通道,像素大小为10*10=100个,第一次transpose(1,2)之后,变成了2*10*3*10,其实这步没有什么实际意义,主要是为了下一步交换出通道做铺垫,但还是可以解释一下,这就相当于把结构变成了 2(张图)*10个(10个后面这样的像素)*(3条通道*10个像素),即有2本书*每本书有10页*每页有3*10个文字(比如有黄红蓝3种颜色的文字,每种有10个字),接着transpose变为2*10*10*3,即2本书*每本书有10页*每页有10*3个文字(好比原来每种颜色文字有一行10个字,3个颜色有3行;现在变成了10行,每一个文字有3个颜色),画个图方便理解:

其实就是将input变成了一个单个像素上含有C个通道的值的形式,最后拉成h*w*n个值,方便计算。

import torch
import torch.nn.functional as F


def cross_entropy2d(input, target, weight=None, size_average=True):
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Handle inconsistent size between input and target
    if h != ht and w != wt:  # 如果输入和目标大小不一致,上采样
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # 其实就是将input变成了一个像素上含有C个通道的值的形式,最后拉成h*w*n个值
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)# GT是单通道,因此拉成一维就可以了
    loss = F.cross_entropy(
        input, target, weight=weight, size_average=size_average, ignore_index=250
    )# weight是各类别的权重
    return loss


def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
    if not isinstance(input, tuple):
        return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average)

    # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
    if scale_weight is None:  # scale_weight: torch tensor type
        n_inp = len(input)
        scale = 0.4
        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to(
            target.device
        )

    loss = 0.0
    for i, inp in enumerate(input):
        loss = loss + scale_weight[i] * cross_entropy2d(
            input=inp, target=target, weight=weight, size_average=size_average
        )

    return loss


def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):

    batch_size = input.size()[0]

    def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):

        n, c, h, w = input.size()
        input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        target = target.view(-1)
        loss = F.cross_entropy(
            input, target, weight=weight, reduce=False, size_average=False, ignore_index=250
        )

        topk_loss, _ = loss.topk(K)# 得到前k个最大的loss
        reduced_topk_loss = topk_loss.sum() / K # 求均值,即1个batch的loss

        return reduced_topk_loss

    loss = 0.0
    # Bootstrap from each image not entire batch
    for i in range(batch_size):
        loss += _bootstrap_xentropy_single(
            input=torch.unsqueeze(input[i], 0),
            target=torch.unsqueeze(target[i], 0),
            K=K,
            weight=weight,
            size_average=size_average,
        )
    return loss / float(batch_size)# 求整个的loss

 

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值