mmsegmentation 添加L1Loss

本文介绍了如何在PyTorch的mmseg.models.losses模块中实现L1Loss,重点讲解了如何处理输入shape不匹配的问题并通过one-hot编码转换。此外,展示了如何通过@LOSSES.register_module注册损失函数并实现loss_name属性以进行组合和追踪。
摘要由CSDN通过智能技术生成

mmseg/models/losses/模块中添加L1Loss定义:

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES


@LOSSES.register_module()
class L1Loss(nn.Module):
    # TODO: weight
    def __init__(self, loss_name='loss_l1', **kwargs):
        super(L1Loss, self).__init__()
        self._loss_name = loss_name

    def forward(self, pred, target, weight=None, ignore_index=None):
   		# pred: (n,c,h,w)   target: (n,h,w)
        classes = pred.shape[1]
        size = list(target.shape)
        size.append(classes)  # (n,h,w,c)
        target_one_hot = target.view(-1)  # (n*h*w)
        ones = torch.sparse.torch.eye(classes).to(target_one_hot.device)
        ones = ones.index_select(0, target_one_hot)  # (n*h*w, classes)
        ones = ones.view(*size)  # (n,h,w,c)
        target_one_hot = ones.permute(0, 3, 1, 2)  # (n,c,h,w)
        loss = nn.L1Loss()(pred, target_one_hot)
        return loss

	@property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name

注意必须要有loss_name方法,并且返回的loss_name需要以loss_作为前缀。

传入的pred和target的shape不一致,需要转为一致才可以直接调用nn.L1Loss()方法。
pred.shape: (n,c,h,w)
target.shape: (n,h,w)
所以需要将target转one-hot。转one-hot方法:index_select
(n,h,w) => (n,h,w,c) => (n,c,h,w)

拓展阅读:Pytorch中,将label变成one hot编码的两种方式

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值