接上面部分,对五六七部分进行详解,这篇介绍第六部分。
一、回顾
第六部分对得到得dets进行后处理:
dets = self.post_process(dets, meta, scale)
torch.cuda.synchronize()
post_process_time = time.time()
post_time += post_process_time - decode_time
detections.append(dets)
post_process在ctdet.py中出现:
def post_process(self, dets, meta, scale=1):
dets = dets.detach().cpu().numpy()
dets = dets.reshape(1, -1, dets.shape[2])
dets = ctdet_post_process(
dets.copy(), [meta['c']], [meta['s']],
meta['out_height'], meta['out_width'], self.opt.num_classes)
for j in range(1, self.num_classes + 1):
dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
dets[0][j][:, :4] /= scale
return dets[0]
使用ctdet_post_process进行后处理,
最后得到得dets是一个len=80的张量(列表?)
其中每个元素是一个N * 5的ndarray。猜测是每个类别中的dets。detections是一个包含dets的列表。
二、详解
主要是ctdet_post_process部分
输入是这些:
ctdet_post_process来源于utils/post_process,代码和解释如下:
def ctdet_post_process(dets, c, s, h, w, num_classes):
# dets: batch x max_dets x dim
# return 1-based class det dict
ret = []
# 对于每个batch
for i in range(dets.shape[0]):
top_preds = {}
# 输入dets的左上角,中心点,最长边,(128, 128):heatmap的长宽
# 得到变换后的dets,怎么变换的还未知
dets[i, :, :2] = transform_preds(
dets[i, :, 0:2], c[i], s[i], (w, h))
# 输入dets的右上角,。。。同理
dets[i, :, 2:4] = transform_preds(
dets[i, :, 2:4], c[i], s[i], (w, h))
classes = dets[i, :, -1]
# 将第j个类的结果放到top_preds[j+1]中,top_preds是一个dict
for j in range(num_classes):
inds = (classes == j)
top_preds[j + 1] = np.concatenate([
dets[i, inds, :4].astype(np.float32),
dets[i, inds, 4:5].astype(np.float32)], axis=1).tolist()
# 将一张图片的结果放到ret中
ret.append(top_preds)
return ret