语义分割任务中如何处理label为255的标签

语义分割常用数据集Cityscapes中会将不需要用到的像素标签设置为255,但初学者可能会遇到困惑,我们在训练或者评估的时候遇到255的标签该怎么办呢?我们需要做的是忽略。

训练计算loss时的处理

import torch
from torch import nn

class CrossEntropy2d(nn.Module):
    def __init__(self, ignore_label=255):
        super().__init__()
        self.ignore_label = ignore_label

    def forward(self, predict, target):
        """
        :param predict: [batch, num_class, height, width]
        :param target: [batch, height, width]
        :return: entropy loss
        """
        target_mask = target != self.ignore_label  # [batch, height, width]筛选出所有需要训练的像素点标签
        target = target[target_mask]  # [num_pixels]
        batch, num_class, height, width = predict.size()
        predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
        predict = predict[target_mask.unsqueeze(-1).repeat(1, 1, 1, num_class)].view(-1, num_class)
        loss = F.cross_entropy(predict, target)
        return loss	

上面代码的核心就是通过索引将需要训练的像素点拿出来进行交叉熵损失的计算

评估计算Pixel accuracy 和Mean IoU

def eval_metrics(predict, target, ignore_label=255):
    # 预处理 将ignore label对应的像素点筛除
    target_mask = (target != ignore_label)  # [batch, height, width]筛选出所有需要训练的像素点标签
    target = target[target_mask]  # [num_pixels]
    batch, num_class, height, width = predict.size()
    predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
    
    # 计算pixel accuracy
    predict = predict[target_mask.unsqueeze(-1).repeat(1, 1, 1, num_class)].view(-1, num_class)
    predict = predict.argmax(dim=1)
    num_pixels = target.numel()
    correct = (predict == target).sum()
    pixel_acc = correct / num_pixels
    
    # 计算所有类别的mIoU
    predict = predict + 1
    target = target + 1
    intersection = predict * (predict == target).long()
    area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
    area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
    area_label = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
    mIoU = area_inter.mean() / (area_pred + area_label - area_inter).mean()
    return pixel_acc, mIoU
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
分析这个代码class OhemCrossEntropy(nn.Module): def __init__(self, ignore_label=-1, thres=0.7, min_kept=100000, weight=None): super(OhemCrossEntropy, self).__init__() self.thresh = thres self.min_kept = max(1, min_kept) self.ignore_label = ignore_label self.criterion = nn.CrossEntropyLoss( weight=weight, ignore_index=ignore_label, reduction='none' ) def _ce_forward(self, score, target): ph, pw = score.size(2), score.size(3) h, w = target.size(1), target.size(2) if ph != h or pw != w: score = F.interpolate(input=score, size=( h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) loss = self.criterion(score, target) return loss def _ohem_forward(self, score, target, **kwargs): ph, pw = score.size(2), score.size(3) h, w = target.size(1), target.size(2) if ph != h or pw != w: score = F.interpolate(input=score, size=( h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) pred = F.softmax(score, dim=1) pixel_losses = self.criterion(score, target).contiguous().view(-1) mask = target.contiguous().view(-1) != self.ignore_label tmp_target = target.clone() tmp_target[tmp_target == self.ignore_label] = 0 pred = pred.gather(1, tmp_target.unsqueeze(1)) pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() min_value = pred[min(self.min_kept, pred.numel() - 1)] threshold = max(min_value, self.thresh) pixel_losses = pixel_losses[mask][ind] pixel_losses = pixel_losses[pred < threshold] return pixel_losses.mean() def forward(self, score, target): if config.MODEL.NUM_OUTPUTS == 1: score = [score] weights = config.LOSS.BALANCE_WEIGHTS assert len(weights) == len(score) functions = [self._ce_forward] * \ (len(weights) - 1) + [self._ohem_forward] return sum([ w * func(x, target) for (w, x, func) in zip(weights, score, functions) ])
04-23
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值