SAM 掩膜生成器学习笔记,从多个生成掩膜中得到分数最高的掩膜

最近在使用SAM结合YOLO实现自动标注时,发现SAM可以同时生成3个掩膜并返回一个n*3的torch张量,记录了n个实例的3个掩膜对应分数;然后就想每次生成掩膜都返回分数最高的掩膜


@torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask."""

代码如下:

masks, scores, logits = self.SamPredictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=multimask,
        )
        # 高质量优化
        if multimask:
            n = masks.shape[0]  # 获取masks的大小
            max_scores_indices = torch.argmax(scores, dim=1)  # 获取分数最高的索引
            # 选择分数最高的掩膜
            masks = masks[torch.arange(n), max_scores_indices]

masks形状为N*3*w*h,N为掩码数,3是默认生成掩膜数量;
scores形状为N*3,记录了每个实例的3个生成掩膜的分数;
torch.arange(n)形状为N,目的是索引masks的一维
max_scores_indices形状为N,记录分数最大的掩膜的索引,目的是索引masks的二维

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值