分割模型OHEM

基于pytorch分割模型使用OHEM

OHEM在分割模型的应用,让模型更专注于难分类样本点(预测的概率值低)

def sample(self, seg_logit, seg_label):
   """Sample pixels that have high loss or with low prediction confidence.

    Args:
        seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
        seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

    Returns:
        torch.Tensor: segmentation weight, shape (N, H, W)
    """
    with torch.no_grad():
        assert seg_logit.shape[2:] == seg_label.shape[2:]
        assert seg_label.shape[1] == 1
        seg_label = seg_label.squeeze(1).long()
        batch_kept = self.min_kept * seg_label.size(0)
        valid_mask = seg_label != self.context.ignore_index
        # 生成和seg_logit相同类型、相同device的0 tensor,形状和seg_label相同:(N, 1, H, W)
        seg_weight = seg_logit.new_zeros(size=seg_label.size())
        # 将seg_weight拉直为Nx1xHxW
        valid_seg_weight = seg_weight[valid_mask]
		# 使用softmax对预测值计算概率值
        seg_prob = F.softmax(seg_logit, dim=1)

        tmp_seg_label = seg_label.clone().unsqueeze(1)
        tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
        # 通过gather在第1维按照索引tmp_seg_label取值
        seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
        sort_prob, sort_indices = seg_prob[valid_mask].sort()
		# 按照最低数量和排序重新计算阈值
        if sort_prob.numel() > 0:
            min_threshold = sort_prob[min(batch_kept,
                                          sort_prob.numel() - 1)]
        else:
            min_threshold = 0.0
        threshold = max(min_threshold, self.thresh)
        valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
        # 按照索引重新赋值回去
        seg_weight[valid_mask] = valid_seg_weight

        return seg_weight
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值