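This article walks through the prediction post-processing code of yolov8-seg using a concrete example. Overall, the image is first preprocessed and then fed to the model to obtain predictions. The detection output is filtered by a confidence threshold and NMS to obtain the valid objects, and those objects are then used to select the matching masks. Finally, the masks are colored and the image is resized back to the size of the input image, now carrying the mask information. Note one difference between yolov8 and earlier versions: in the inference output pred, the objectness confidence and the class scores have been merged into one, so where each prediction used to have 85 values it now has 84 (with 80 classes: 4 box values plus 80 class scores).
Data loading and model loading are not covered again here. The focus is the post-processing, starting from the part of predictor.py shown in the figure below.
As the figure above shows, preds contains two items: the first is the detection output and the second is the segmentation output. Calling self.postprocess leads into the code in the figure below.
In segment/predict.py, the first item of the prediction output, of shape (1,37,7140), is taken and passed to non-maximum suppression.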
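The shape (1,37,7140) can be reproduced with a little arithmetic. The sketch below is only an illustration: the helper names are hypothetical, and it assumes the usual three detection scales with strides 8, 16 and 32 and a 544x640 model input, which is the input size reported later in this article.

```python
def num_candidates(height, width, strides=(8, 16, 32)):
    """Count prediction locations: one per grid cell on each scale (strides are an assumption)."""
    return sum((height // s) * (width // s) for s in strides)


def num_channels(num_classes, num_masks=0, objectness=False):
    """4 box values + optional objectness value + class scores + mask coefficients."""
    return 4 + (1 if objectness else 0) + num_classes + num_masks


candidates = num_candidates(544, 640)                     # 68*80 + 34*40 + 17*20
seg_channels = num_channels(num_classes=1, num_masks=32)  # 4 + 1 + 32
print((1, seg_channels, candidates))                      # (1, 37, 7140)

# The 85-versus-84 difference mentioned above, for an 80-class model:
print(num_channels(80, objectness=True), num_channels(80))  # 85 84
```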
This enters the non_max_suppression function in ops.py.
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS
    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.transpose(0, -1)[xc[xi]]  # confidence
        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)
        # If none remain process next image
        if not x.shape[0]:
            continue
        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]
        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy
        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded
    return output
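Code walkthrough:
First, the confidence and IoU thresholds are checked to lie between 0 and 1. If they do not, the assertion fails with a message about the invalid threshold.
The type of prediction is checked: if it is still a list or tuple, the inference output is extracted here; if it was already extracted, this step is skipped.
The batch size bs is read; here it is 1.
The number of classes nc is computed; here it is 1, with nm = 32 being the number of masks.
mi is the index at which the mask coefficients start, i.e. each prediction holds the box first, then the class confidences, then the mask coefficients.
xc selects candidates among the 7140 predictions by confidence score: True where the highest class score exceeds conf_thres, False otherwise.
A series of constants is then set, such as max_wh and max_nms.
output is initialized as a list holding one empty tensor of 0 rows and 38 columns per image.
The predictions are then processed one image at a time with for xi, x in enumerate(prediction): so with a single image, xi=0 and there is only one x.
The True/False values in xc select the predictions that meet the threshold. Here prediction is (1,37,7140); after transposing and filtering, x has shape (65,37), meaning 65 predictions pass the confidence threshold.
A priori labels are appended only when labels is provided (autolabelling), which is not the case here.
If no prediction passes the confidence threshold, the rest of the loop body is skipped and the next image is processed: if not x.shape[0]: continue
x is split into the box, the class scores cls and the mask coefficients: box, cls, mask = x.split((4, nc, nm), 1). The tuple gives the size of each chunk and the final 1 is the dimension to split along, so the 37 columns become 4, 1 and 32 columns.
The box representation is converted from center/width/height to corner coordinates (xywh2xyxy).
multi_label is checked. Here there is a single class, so the single-label branch runs: conf, j = cls.max(1, keepdim=True) returns the maximum value along dimension 1 together with its index.
With the score conf and the index j, x is rebuilt by concatenating box, conf, j and mask, giving shape (65,38).
If classes is not None, the detections are filtered to the requested classes.
The shape of x is checked again; if no boxes remain, the loop continues with the next image.
The detections are sorted by conf in descending order and truncated to at most max_nms entries. This value is usually large enough that every candidate is kept.
c is the class index multiplied by max_wh (or by 0 if agnostic is set). Adding it to the boxes keeps boxes of different classes from suppressing each other. Because every class index here is 0, c is all zeros.
boxes, scores and the IoU threshold are passed to non-maximum suppression. The result i holds the indices of the detections that are kept; here there are 7.
The kept indices are truncated to max_det. In most cases this removes nothing.
The surviving detections are taken from x by index i and stored in output: output[xi] = x[i]
At this point the image has gone through confidence thresholding and non-maximum suppression, which fixes the class, score and box of each object in it. The function returns, and execution goes back to postprocess in segment/predict.py.
p is the returned result; its single element has shape (7,38), i.e. 7 objects remain.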
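To make the single-label branch and the class-offset trick concrete, here is a minimal dependency-free sketch. It is not the ultralytics or torchvision implementation: the names xywh_to_xyxy, iou and simple_nms are hypothetical, the rows are made-up toy data, and a plain greedy loop stands in for torchvision.ops.nms.

```python
def xywh_to_xyxy(box):
    """Convert (center_x, center_y, width, height) to (x1, y1, x2, y2)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def simple_nms(rows, conf_thres=0.25, iou_thres=0.45, max_wh=7680, agnostic=False):
    """rows: list of (xywh box, class scores). Returns kept (xyxy box, conf, cls) tuples."""
    dets = []
    for box, scores in rows:
        conf = max(scores)          # best class only
        cls = scores.index(conf)
        if conf > conf_thres:       # confidence filter
            dets.append((xywh_to_xyxy(box), conf, cls))
    dets.sort(key=lambda d: d[1], reverse=True)  # highest score first

    kept = []  # pairs of (offset box used for IoU, original detection)
    for box, conf, cls in dets:
        offset = 0 if agnostic else cls * max_wh  # push each class into its own region
        shifted = tuple(v + offset for v in box)
        if all(iou(shifted, other) <= iou_thres for other, _ in kept):
            kept.append((shifted, (box, conf, cls)))
    return [det for _, det in kept]


rows = [
    ((100, 100, 50, 50), [0.9, 0.1]),  # class 0, kept
    ((102, 101, 50, 50), [0.8, 0.1]),  # class 0, overlaps the first box, suppressed
    ((102, 101, 50, 50), [0.1, 0.7]),  # class 1 at the same place, kept unless agnostic
    ((300, 300, 40, 40), [0.2, 0.1]),  # below conf_thres, dropped
]
print([cls for _, _, cls in simple_nms(rows)])  # [0, 1]
print(len(simple_nms(rows, agnostic=True)))     # 1
```

In the example of this article every detection has class index 0, so the offset is always 0 and the class-aware and agnostic results coincide.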
Next comes the handling of the instance segmentation output.
def postprocess(self, preds, img, orig_img):
    masks = []
    # TODO: filter by classes
    p = ops.non_max_suppression(preds[0],
                                self.args.conf,
                                self.args.iou,
                                agnostic=self.args.agnostic_nms,
                                max_det=self.args.max_det,
                                nm=32)
    proto = preds[1][-1]
    for i, pred in enumerate(p):
        shape = orig_img[i].shape if self.webcam else orig_img.shape
        if not len(pred):
            continue
        if self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
            masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2]))  # HWC
        else:
            masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True))  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
    return (p, masks)
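proto is taken from the prediction output, with shape (1,32,136,160). For a (640,640) input image it would be (1,32,160,160).
The detections that passed the confidence and NMS filtering are read; pred has shape (7,38).
The shape of the original image is read into shape; here it is (1024,1224,3).
If there are no predictions for the image, it is skipped.
The segmentation output is then post-processed. Four inputs are passed: proto[i], the mask coefficients pred[:, 6:], the box coordinates pred[:, :4], and img.shape[2:], which is (544,640).
This enters the process_mask function in ops.py.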
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    It takes the output of the mask head, and applies the mask to the bounding boxes. This is faster but produces
    downsampled quality of mask
    Args:
        protos (torch.tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
        bboxes (torch.tensor): [n, 4], n is number of masks after nms
        shape (tuple): the size of the input image (h,w)
    Returns:
        (torch.tensor): The processed masks.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= mw / iw
    downsampled_bboxes[:, 2] *= mw / iw
    downsampled_bboxes[:, 3] *= mh / ih
    downsampled_bboxes[:, 1] *= mh / ih
    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    return masks.gt_(0.5)
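The code is shown above; the walkthrough follows:
The shape of protos is read: here c is 32, mh is 136 and mw is 160.
The model input size is read: ih is 544 and iw is 640.
A matrix multiplication, a sigmoid and a reshape produce masks. masks_in is (7,32) and protos is (32,136,160), flattened to (32,21760), so masks becomes (7,136,160).
downsampled_bboxes is a copy of the 7 box coordinates.
The coordinates are then scaled from the (544,640) input size down to the (136,160) mask size, giving the final downsampled_bboxes.
The masks are then cropped: **this enters the crop_mask function in ops.py.** Its inputs are masks and downsampled_bboxes.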
def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
    Args:
        masks (torch.tensor): [h, w, n] tensor of masks
        boxes (torch.tensor): [n, 4] tensor of bbox coordinates in relative point form
    Returns:
        (torch.tensor): The masks are being cropped to the bounding box.
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(1,1,n)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,w,1)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(h,1,1)
    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
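The shape of masks is read: n is 7, h is 136 and w is 160.
boxes is split into four parts with torch.chunk(): x1, y1, x2, y2.
r and c hold the column and row indices. The comparison keeps each mask's values inside its own box and sets everything outside the box to 0.
Back in process_mask, the masks are upsampled from (7,136,160) to (7,544,640), the same size as the model input image.
masks.gt_(0.5) tests whether each value in masks is greater than 0.5; values above it are set to 1 (True) and the rest to 0.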
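The chain of coefficients times prototypes, sigmoid, box crop and 0.5 threshold can be sketched on toy data without any tensor library. This is an illustration only: sketch_process_mask is a hypothetical helper, nested lists stand in for tensors, and the bilinear upsampling step is left out, so the threshold is applied at prototype resolution.

```python
import math


def sketch_process_mask(protos, coeffs, boxes, input_hw):
    """protos: c prototypes of mh x mw values; coeffs: n x c; boxes: n x (x1, y1, x2, y2) in input pixels."""
    c, mh, mw = len(protos), len(protos[0]), len(protos[0][0])
    ih, iw = input_hw
    masks = []
    for coef, (x1, y1, x2, y2) in zip(coeffs, boxes):
        # Scale the box from input-image pixels down to prototype resolution.
        bx1, bx2 = x1 * mw / iw, x2 * mw / iw
        by1, by2 = y1 * mh / ih, y2 * mh / ih
        mask = []
        for row in range(mh):
            line = []
            for col in range(mw):
                # Linear combination of the prototypes at this pixel, then sigmoid.
                logit = sum(coef[k] * protos[k][row][col] for k in range(c))
                prob = 1 / (1 + math.exp(-logit))
                # Crop: only pixels inside the (downsampled) box can stay on.
                inside = bx1 <= col < bx2 and by1 <= row < by2
                line.append(prob > 0.5 and inside)
            mask.append(line)
        masks.append(mask)
    return masks


# Two 4x4 prototypes for a 16x16 input (same 1/4 ratio as 136x160 for 544x640).
proto_a = [[2.0, 2.0, -2.0, -2.0] for _ in range(4)]  # positive on the left half
proto_b = [[0.0] * 4 for _ in range(4)]
masks = sketch_process_mask([proto_a, proto_b], [[1.0, 0.0]], [(0, 0, 4, 16)], (16, 16))
print(masks[0][0])  # [True, False, False, False]
```

The second column has a probability above 0.5 as well, but it lies outside the downsampled box (columns 0 to 1), so the crop removes it.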
The function returns, and execution goes back to postprocess in segment/predict.py.
The returned masks have shape (7,544,640). The following line then runs:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
This rescales the box coordinates from the model input size to the original image size.
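The article does not list the body of ops.scale_boxes, so the following is only a rough sketch of the idea under stated assumptions: the input image is a letterboxed copy of the original with padding split evenly on both sides, and sketch_scale_box is a hypothetical single-box helper rather than the library function.

```python
def sketch_scale_box(input_hw, box, orig_hw):
    """Map one (x1, y1, x2, y2) box from model-input pixels to original-image pixels."""
    ih, iw = input_hw
    oh, ow = orig_hw
    gain = min(ih / oh, iw / ow)   # resize factor used by the letterbox
    pad_x = (iw - ow * gain) / 2   # horizontal padding on each side
    pad_y = (ih - oh * gain) / 2   # vertical padding on each side
    x1, y1, x2, y2 = box
    x1, x2 = (x1 - pad_x) / gain, (x2 - pad_x) / gain
    y1, y2 = (y1 - pad_y) / gain, (y2 - pad_y) / gain

    def clip(value, upper):
        return min(max(value, 0.0), float(upper))

    return tuple(round(v) for v in (clip(x1, ow), clip(y1, oh), clip(x2, ow), clip(y2, oh)))


# A box from the center of the 544x640 input to its bottom-right corner.
print(sketch_scale_box((544, 640), (320, 272, 640, 544), (1024, 1224)))
# (612, 512, 1224, 1024)
```

The center of the input maps to the center of the original image, and the corner that lies in the padded border is clipped to the image edge.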
This completes postprocess in segment/predict.py, and execution returns to predictor.py.
for batch in self.dataset:
    self.run_callbacks("on_predict_batch_start")
    path, im, im0s, vid_cap, s = batch
    visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
    with self.dt[0]:
        im = self.preprocess(im)
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
    # Inference
    with self.dt[1]:
        preds = model(im, augment=self.args.augment, visualize=visualize)
    # postprocess
    with self.dt[2]:
        preds = self.postprocess(preds, im, im0s)
    for i in range(len(im)):
        if self.webcam:
            path, im0s = path[i], im0s[i]
        p = Path(path)
        s += self.write_results(i, preds, (p, im, im0s))
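The returned preds consists of two lists: the element of the first has shape (7,38) and the element of the second has shape (7,544,640).
What follows is writing the results.
This enters write_results in segment/predict.py: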
def write_results(self, idx, preds, batch):
    p, im, im0 = batch
    log_string = ""
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    self.seen += 1
    if self.webcam:  # batch_size >= 1
        log_string += f'{idx}: '
        frame = self.dataset.count
    else:
        frame = getattr(self.dataset, 'frame', 0)
    self.data_path = p
    self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
    log_string += '%gx%g ' % im.shape[2:]  # print string
    self.annotator = self.get_annotator(im0)
    preds, masks = preds
    det = preds[idx]
    if len(det) == 0:
        return log_string
    # Segments
    mask = masks[idx]
    if self.args.save_txt or self.return_outputs:
        shape = im0.shape if self.args.retina_masks else im.shape[2:]
        segments = [
            ops.scale_segments(shape, x, im0.shape, normalize=False) for x in reversed(ops.masks2segments(mask))]
    # Print results
    for c in det[:, 5].unique():
        n = (det[:, 5] == c).sum()  # detections per class
        log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "  # add to string
    # Mask plotting
    self.annotator.masks(
        mask,
        colors=[colors(x, True) for x in det[:, 5]],
        im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
        255 if self.args.retina_masks else im[idx])
    det = reversed(det[:, :6])
    if self.return_outputs:
        self.output["det"] = det.cpu().numpy()
        self.output["segment"] = segments
    # Write results
    for j, (*xyxy, conf, cls) in enumerate(det):
        if self.args.save_txt:  # Write to file
            seg = segments[j].copy()
            seg[:, 0] /= shape[1]  # width
            seg[:, 1] /= shape[0]  # height
            seg = seg.reshape(-1)  # (n,2) to (n*2)
            line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg)  # label format
            with open(f'{self.txt_path}.txt', 'a') as f:
                f.write(('%g ' * len(line)).rstrip() % line + '\n')
        if self.args.save or self.args.save_crop or self.args.show:
            c = int(cls)  # integer class
            label = None if self.args.hide_labels else (
                self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
            self.annotator.box_label(xyxy, label, color=colors(c, True))
            # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
        if self.args.save_crop:
            imc = im0.copy()
            save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True)
    return log_string
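The walkthrough starts at preds, masks = preds, which unpacks preds into preds and masks: preds holds the detections, of shape (7,38), and masks is (7,544,640).
Mask plotting enters the masks method in plotting.py: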
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
    """Plot masks at once.
    Args:
        masks (tensor): predicted masks on cuda, shape: [n, h, w]
        colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
        im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
        alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
    """
    if self.pil:
        # convert to numpy first
        self.im = np.asarray(self.im).copy()
    if len(masks) == 0:
        self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
    colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
    colors = colors[:, None, None]  # shape(n,1,1,3)
    masks = masks.unsqueeze(3)  # shape(n,h,w,1)
    masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
    inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
    mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
    im_gpu = im_gpu.flip(dims=[0])  # flip channel
    im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
    im_gpu = im_gpu * inv_alph_masks[-1] + mcs
    im_mask = (im_gpu * 255)
    im_mask_np = im_mask.byte().cpu().numpy()
    self.im[:] = im_mask_np if retina_masks else scale_image(im_gpu.shape, im_mask_np, self.im.shape)
    if self.pil:
        # convert im back to PIL and update draw
        self.fromarray(self.im)
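The inputs are masks (7,544,640), colors, and im_gpu (3,544,640). There are 7 objects here, but they all belong to one class and therefore share one color.
colors is divided by 255 and reshaped to (7,1,1,3).
masks is also given an extra dimension, becoming (7,544,640,1).
masks is multiplied by the colors and alpha to give masks_color, of shape (7,544,640,3).
A cumulative product and a sum over the objects then give mcs, of shape (544,640,3).
im_gpu and mcs are then combined and converted to 8-bit values, giving im_mask_np.
The image is then resized to the original image size, with the masks drawn in. This enters scale_image in ops.py.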
def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size
    Args:
        im1_shape (tuple): model input shape, [h, w]
        masks (torch.tensor): [h, w, num]
        im0_shape (tuple): the original image shape
        ratio_pad (tuple): the ratio of the padding to the original image.
    Returns:
        masks (torch.tensor): The masks that are being returned.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]
    # masks = masks.permute(2, 0, 1).contiguous()
    # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
    # masks = masks.permute(1, 2, 0).contiguous()
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
    if len(masks.shape) == 2:
        masks = masks[:, :, None]
    return masks
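The scaling gain and the padding are computed. Here gain is 0.5229, and pad is 0 horizontally and about 4.29 vertically.
The top, bottom, left and right bounds are computed, giving the region of the resized image without the gray borders.
masks is cropped to these bounds, so its shape goes from (544,640,3) to (535,640,3).
cv2.resize then restores this to the original image size, (1024,1224,3).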
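The numbers in these steps can be checked with plain arithmetic. The sketch below repeats the gain, padding and crop-bound computation from scale_image for the shapes used in this article; letterbox_crop_bounds is a hypothetical helper name, not part of ops.py.

```python
def letterbox_crop_bounds(input_hw, orig_hw):
    """Return gain, (pad_x, pad_y) and the (top, bottom, left, right) crop bounds."""
    ih, iw = input_hw
    oh, ow = orig_hw
    gain = min(ih / oh, iw / ow)   # gain = old / new
    pad_x = (iw - ow * gain) / 2   # width padding
    pad_y = (ih - oh * gain) / 2   # height padding
    top, left = int(pad_y), int(pad_x)
    bottom, right = int(ih - pad_y), int(iw - pad_x)
    return gain, (pad_x, pad_y), (top, bottom, left, right)


gain, pad, (top, bottom, left, right) = letterbox_crop_bounds((544, 640), (1024, 1224))
print(round(gain, 4), round(pad[1], 2))  # 0.5229 4.29
print(top, bottom, bottom - top)         # 4 539 535
```

The width ratio 640/1224 is the smaller of the two, so the width fills the input completely and only the height receives padding, which is why the crop removes rows but no columns.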
Back in plotting.py, the resized image is stored in self.im, and execution returns to segment/predict.py.
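For reference, the blending that the masks method performs before this resize can be followed for a single pixel. The sketch below mirrors the cumprod and sum expressions from the listing for one pixel position; blend_pixel is a hypothetical helper, colors are assumed to be already scaled to [0, 1], and the channel flip is ignored.

```python
def blend_pixel(pixel, mask_hits, colors, alpha=0.5):
    """pixel: 3 floats in [0, 1]; mask_hits: 0 or 1 per object; colors: one RGB triple per object."""
    inv = 1.0              # running product of (1 - mask * alpha), i.e. the cumprod
    mcs = [0.0, 0.0, 0.0]  # mask color summand
    for hit, color in zip(mask_hits, colors):
        inv *= 1 - hit * alpha
        for ch in range(3):
            mcs[ch] += hit * color[ch] * alpha * inv * 2
    # The image is weighted by the last cumulative product, then the mask color is added.
    return tuple(p * inv + m for p, m in zip(pixel, mcs))


gray = (0.2, 0.2, 0.2)
red = (1.0, 0.0, 0.0)
print(blend_pixel(gray, [0], [red]))  # pixel outside the mask is unchanged
print(blend_pixel(gray, [1], [red]))  # roughly (0.6, 0.1, 0.1)
```

With the default alpha of 0.5 and a single mask, the result equals an even mix of the image pixel and the mask color.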
The remaining steps save the results to a txt file or in other forms.