CV语义分割,类别加权损失函数

作者: 头孢就酒的快乐神仙

转载地址:https://bbs.huaweicloud.com/forum/thread-146576-1-1.html

在复现High-resolution NetWork(HRNet)用于语义分割时,Cityscapes数据集不同类别的物体在计算损失时赋有不同的权重。

weights_list = [0.8373, 0.918, 0.866, 1.0345, 
                1.0166,0.9969, 0.9754, 1.0489,
                0.8786, 1.0023, 0.9539, 0.9843,
                1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507]

在PyTorch中提供torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)可用于实现不同类别的加权计算。

MindSpore的r1.1和r1.2版本并未提供类似功能的损失函数。可以用以下代码实现:

class CrossEntropyLossWithWeights(_Loss):
    def __init__(self, weights, num_classes=19, ignore_label=255):
        super(CrossEntropyLossWithWeights, self).__init__()
        self.weights = weights
        self.resize = F.ResizeBilinear(cfg.train.image_size)
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.zeros = F.Zeros()
        self.fill = F.Fill()
        self.equal = F.Equal()
        self.select = F.Select()
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.mul = P.Mul()
        self.argmax = P.Argmax(output_type=mstype.int32)
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        logits = self.resize(logits)
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))
        logits_ = self.transpose(logits, (0, 2, 3, 1))  # (12, 1024, 2048, 19)
        logits_ = self.reshape(logits_, (-1, self.num_classes))
        labels_float = self.cast(labels_int, mstype.float32)
        weights = self.zeros(labels_float.shape, mstype.float32)
        for i in range(self.num_classes):
            fill_weight = self.fill(mstype.float32, labels_float.shape, self.weights[i])
            equal_ = self.equal(labels_float, i)
            weights = self.select(equal_, fill_weight, weights)
        one_hot_labels = self.one_hot(labels_int, self.num_classes, self.on_value, self.off_value)
        loss = self.ce(logits_, one_hot_labels)
        loss = self.mul(weights, loss)
        loss = self.div(self.sum(loss), self.sum(weights))

        return loss
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值