【代码】CenterNet使用(续)(对五六七部分详解)(六)

接上面部分,对五六七部分进行详解,这篇介绍第六部分。

一、回顾

第六部分对得到得dets进行后处理:

      dets = self.post_process(dets, meta, scale)
      torch.cuda.synchronize()
      post_process_time = time.time()
      post_time += post_process_time - decode_time

      detections.append(dets)

post_process在ctdet.py中出现:

  def post_process(self, dets, meta, scale=1):
    dets = dets.detach().cpu().numpy()
    dets = dets.reshape(1, -1, dets.shape[2])
    dets = ctdet_post_process(
        dets.copy(), [meta['c']], [meta['s']],
        meta['out_height'], meta['out_width'], self.opt.num_classes)
    for j in range(1, self.num_classes + 1):
      dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
      dets[0][j][:, :4] /= scale
    return dets[0]

使用ctdet_post_process进行后处理,

最后得到得dets是一个len=80的张量(列表?)

 其中每个元素是一个N * 5的ndarray。猜测是每个类别中的dets。detections是一个包含dets的列表。

二、详解

主要是ctdet_post_process部分

输入是这些:

ctdet_post_process来源于utils/post_process,代码和解释如下:

def ctdet_post_process(dets, c, s, h, w, num_classes):
  # dets: batch x max_dets x dim
  # return 1-based class det dict
  ret = []
  # 对于每个batch
  for i in range(dets.shape[0]):
    top_preds = {}
    # 输入dets的左上角,中心点,最长边,(128, 128):heatmap的长宽
    # 得到变换后的dets,怎么变换的还未知
    dets[i, :, :2] = transform_preds(
          dets[i, :, 0:2], c[i], s[i], (w, h))
    # 输入dets的右上角,。。。同理
    dets[i, :, 2:4] = transform_preds(
          dets[i, :, 2:4], c[i], s[i], (w, h))
    classes = dets[i, :, -1]
    # 将第j个类的结果放到top_preds[j+1]中,top_preds是一个dict
    for j in range(num_classes):
      inds = (classes == j)
      top_preds[j + 1] = np.concatenate([
        dets[i, inds, :4].astype(np.float32),
        dets[i, inds, 4:5].astype(np.float32)], axis=1).tolist()
    # 将一张图片的结果放到ret中
    ret.append(top_preds)
  return ret

 

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值