在语义分割常用的数据集 Cityscapes 中，不参与训练和评估的像素的标签会被设置为 255。初学者常对此感到困惑：训练或评估时遇到标签为 255 的像素该怎么办？答案是：直接忽略它们，既不计入损失，也不计入评估指标。
训练时计算 loss 的处理
import torch
from torch import nn
class CrossEntropy2d(nn.Module):
    """Pixel-wise cross-entropy loss that skips pixels labelled ``ignore_label``.

    Only the pixels whose label differs from ``ignore_label`` contribute to
    the loss. The result is the mean cross-entropy over those pixels.
    """

    def __init__(self, ignore_label=255):
        super().__init__()
        # Label value whose pixels are excluded from the loss.
        self.ignore_label = ignore_label

    def forward(self, predict, target):
        """
        :param predict: logits, [batch, num_class, height, width]
        :param target: integer class labels, [batch, height, width]
        :return: scalar mean cross-entropy over the non-ignored pixels;
            0 (still attached to the autograd graph) if every pixel is ignored
        """
        target_mask = target != self.ignore_label  # [batch, height, width]
        if not target_mask.any():
            # Cross-entropy over zero pixels is NaN. Return a zero that keeps
            # the graph connected so backward() still works.
            return predict.sum() * 0.0
        # Move the class axis last so the [B, H, W] boolean mask selects whole
        # per-pixel logit vectors, giving [num_pixels, num_class].
        predict = predict.permute(0, 2, 3, 1)[target_mask]
        target = target[target_mask]  # [num_pixels]
        # Use nn.functional via the existing `from torch import nn` import.
        return nn.functional.cross_entropy(predict, target)
上面代码的核心是：先用布尔掩码把标签不等于 255 的像素挑出来，再只对这些像素计算交叉熵损失。（PyTorch 的 `torch.nn.functional.cross_entropy` 也可以直接通过 `ignore_index=255` 参数实现同样的效果。）
评估时计算 Pixel Accuracy 和 Mean IoU
def _class_histogram(labels, num_class):
    """Count each class id in ``labels``; ids outside [0, num_class) are dropped."""
    labels = labels[(labels >= 0) & (labels < num_class)]
    return torch.bincount(labels, minlength=num_class).float()


def eval_metrics(predict, target, ignore_label=255):
    """Compute pixel accuracy and mean IoU, skipping ``ignore_label`` pixels.

    :param predict: per-class scores/logits, [batch, num_class, height, width]
    :param target: integer class labels, [batch, height, width]
    :param ignore_label: label value excluded from both metrics
    :return: (pixel_acc, mIoU) as 0-dim float tensors. pixel_acc is NaN when
        no pixel is labelled. mIoU is the mean of per-class IoU over the
        classes present in the prediction or the label, and is NaN if none are.
    """
    num_class = predict.size(1)
    # Keep only the labelled pixels. Take argmax over the class axis first so
    # the [B, H, W] mask can index the predictions directly.
    target_mask = target != ignore_label  # [batch, height, width]
    pred = predict.argmax(dim=1)[target_mask]  # [num_pixels]
    target = target[target_mask]  # [num_pixels]

    # Pixel accuracy: fraction of labelled pixels predicted correctly.
    hit = pred == target
    pixel_acc = hit.sum().float() / target.numel()

    # Per-class areas (labels outside [0, num_class) are not counted).
    area_inter = _class_histogram(target[hit], num_class)
    area_pred = _class_histogram(pred, num_class)
    area_label = _class_histogram(target, num_class)
    area_union = area_pred + area_label - area_inter

    # Mean of per-class IoU, not a ratio of means. Classes that appear in
    # neither prediction nor label have IoU 0/0 and are left out.
    present = area_union > 0
    mIoU = (area_inter[present] / area_union[present]).mean()
    return pixel_acc, mIoU