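Recently, while using SAM together with YOLO for automatic annotation, I noticed that SAM can generate 3 candidate masks per prompt and also returns an N×3 torch tensor holding the scores of those 3 masks for each of the N instances. I therefore wanted each prediction to return only the highest-scoring mask. The relevant part of the predict_torch signature and docstring in SAM is: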
@torch.no_grad()
def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """multimask_output (bool): If true, the model will return three masks.
        For ambiguous input prompts (such as a single click), this will often
        produce better masks than a single prediction. If only a single
        mask is needed, the model's predicted quality score can be used
        to select the best mask. For non-ambiguous prompts, such as multiple
        input prompts, multimask_output=False can give better results.
    return_logits (bool): If true, returns un-thresholded masks logits
        instead of a binary mask."""
My code is as follows:
masks, scores, logits = self.SamPredictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=multimask,
)
# Keep only the highest-quality mask
if multimask:
    n = masks.shape[0]  # number of instances (one per box prompt)
    max_scores_indices = torch.argmax(scores, dim=1)  # index of the highest score for each instance
    # Select the highest-scoring mask for each instance
    masks = masks[torch.arange(n), max_scores_indices]
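masks has shape N×3×H×W, where N is the number of instances (one per box prompt) and 3 is the default number of masks generated per instance;
scores has shape N×3 and holds the scores of the 3 generated masks for each instance;
torch.arange(n) has shape N and indexes the first dimension of masks;
max_scores_indices has shape N, holds the index of the highest-scoring mask for each instance, and indexes the second dimension of masks. Because both index tensors have shape N, they are paired element by element, so the result has shape N×H×W.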
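The indexing described above can be tried on dummy tensors without loading SAM. The following is only a minimal sketch: the helper name select_best_masks and its keepdim flag are my own assumptions for illustration and are not part of the SAM API, and the random masks simply stand in for the output of predict_torch.

```python
import torch


def select_best_masks(masks: torch.Tensor, scores: torch.Tensor, keepdim: bool = False):
    """Hypothetical helper: pick the highest-scoring candidate mask per instance.

    masks:  (N, C, H, W) candidate masks, C candidates per instance
    scores: (N, C) one quality score per candidate
    """
    n = masks.shape[0]
    best = torch.argmax(scores, dim=1)            # (N,) best candidate index per instance
    rows = torch.arange(n, device=masks.device)   # (N,) instance indices 0..N-1
    best_masks = masks[rows, best]                # (N, H, W): pairs rows[i] with best[i]
    best_scores = scores[rows, best]              # (N,)
    if keepdim:
        # Restore the candidate axis for code that expects (N, 1, H, W)
        best_masks = best_masks.unsqueeze(1)
    return best_masks, best_scores, best


# Dummy data standing in for SAM output: 4 instances, 3 candidates, 8x8 masks
torch.manual_seed(0)
masks = torch.rand(4, 3, 8, 8) > 0.5
scores = torch.tensor([
    [0.10, 0.90, 0.30],
    [0.80, 0.20, 0.50],
    [0.40, 0.40, 0.70],
    [0.60, 0.95, 0.10],
])

best_masks, best_scores, best_idx = select_best_masks(masks, scores)
print(best_masks.shape)   # torch.Size([4, 8, 8])
print(best_idx.tolist())  # [1, 0, 2, 1]
```

Note that the selection drops the candidate axis. If later steps expect the N×1×H×W layout that multimask_output=False produces, passing keepdim=True (or calling unsqueeze(1) on the result) restores it.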