代码:https://github.com/xingyizhou/CenterNet
可参考:https://www.cnblogs.com/pprp/p/13404184.html
一:数据处理
图片标签处理的位置:CenterNet/src/lib/datasets/sample/ctdet.py
coco数据集为例
def __getitem__(self, index):
    """Load one sample and build the CenterNet detection targets for it.

    The image is read from disk, augmented when ``self.split == 'train'``
    (random scale / centre shift, horizontal flip, colour jitter), warped
    to the network input size and normalised.  Every annotated box is then
    mapped onto the down-sampled output grid and rendered as a Gaussian
    peak in the class heatmap, together with its size and sub-pixel
    centre offset.

    Args:
        index: Position of the sample in ``self.images``.

    Returns:
        A dict containing:

        * ``'input'``: normalised float image after ``transpose(2, 0, 1)``.
        * ``'hm'``: float32 heatmap of shape
          ``(num_classes, output_h, output_w)``.
        * ``'reg_mask'``: uint8 vector of length ``max_objs``; 1 marks
          slots that hold a valid object.
        * ``'ind'``: int64 vector of length ``max_objs`` with the
          flattened centre index ``y * output_w + x``.
        * ``'wh'``: float32 ``(max_objs, 2)`` box width/height on the
          output grid.  Replaced by ``'dense_wh'`` / ``'dense_wh_mask'``
          when ``opt.dense_wh`` is set, or by ``'cat_spec_wh'`` /
          ``'cat_spec_mask'`` when ``opt.cat_spec_wh`` is set.
        * ``'reg'``: float32 ``(max_objs, 2)`` centre offsets, only when
          ``opt.reg_offset`` is set.
        * ``'meta'``: dict with ``'c'``, ``'s'``, ``'gt_det'`` and
          ``'img_id'``, only when ``opt.debug > 0`` or the split is not
          ``'train'``.

    Raises:
        IOError: If the image file cannot be read by ``cv2.imread``.
    """
    img_id = self.images[index]
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    img_path = os.path.join(self.img_dir, file_name)
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    anns = self.coco.loadAnns(ids=ann_ids)
    # Target arrays have a fixed number of slots; surplus objects are dropped.
    num_objs = min(len(anns), self.max_objs)

    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising;
    # fail here with the path rather than on `img.shape` below.
    if img is None:
        raise IOError('Failed to read image: {}'.format(img_path))
    # Debug aid: uncomment to dump the raw image.
    # cv2.imwrite('input.jpg', img)

    height, width = img.shape[0], img.shape[1]
    # c: centre of the crop in source-image pixels (x, y).
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    if self.opt.keep_res:
        # Sets the low bits then adds one; presumably opt.pad is 2**k - 1 so
        # the result is a multiple of pad + 1 -- confirm against the options.
        input_h = (height | self.opt.pad) + 1
        input_w = (width | self.opt.pad) + 1
        s = np.array([input_w, input_h], dtype=np.float32)
    else:
        # s: side length of the source region that is mapped to the input.
        s = max(img.shape[0], img.shape[1]) * 1.0
        input_h, input_w = self.opt.input_h, self.opt.input_w

    flipped = False
    if self.split == 'train':
        if not self.opt.not_rand_crop:
            # Random scale in [0.6, 1.3] and a random crop centre kept away
            # from the image borders.
            s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
            w_border = self._get_border(128, img.shape[1])
            h_border = self._get_border(128, img.shape[0])
            c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
            c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
        else:
            # Gaussian jitter of centre and scale, clipped to a fixed range.
            sf = self.opt.scale
            cf = self.opt.shift
            c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)

        if np.random.random() < self.opt.flip:
            flipped = True
            img = img[:, ::-1, :]
            # Mirror the crop centre so it follows the flipped image.
            c[0] = width - c[0] - 1

    trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input,
                         (input_w, input_h),
                         flags=cv2.INTER_LINEAR)
    inp = (inp.astype(np.float32) / 255.)
    if self.split == 'train' and not self.opt.no_color_aug:
        color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    inp = (inp - self.mean) / self.std
    inp = inp.transpose(2, 0, 1)

    # Targets live on the network output grid (input size / down_ratio).
    output_h = input_h // self.opt.down_ratio
    output_w = input_w // self.opt.down_ratio
    num_classes = self.num_classes
    trans_output = get_affine_transform(c, s, 0, [output_w, output_h])

    hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    ind = np.zeros((self.max_objs), dtype=np.int64)
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    cat_spec_wh = np.zeros((self.max_objs, num_classes * 2), dtype=np.float32)
    cat_spec_mask = np.zeros((self.max_objs, num_classes * 2), dtype=np.uint8)

    draw_gaussian = draw_msra_gaussian if self.opt.mse_loss else \
        draw_umich_gaussian

    gt_det = []
    for k, ann in enumerate(anns[:num_objs]):
        bbox = self._coco_box_to_bbox(ann['bbox'])
        cls_id = int(self.cat_ids[ann['category_id']])
        if flipped:
            # Mirror x1/x2 and swap them so that x1 <= x2 still holds.
            bbox[[0, 2]] = width - bbox[[2, 0]] - 1
        # Map both corners onto the output grid and clip to its bounds.
        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
        bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        # Boxes that collapse after cropping/clipping produce no target.
        if h > 0 and w > 0:
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            radius = max(0, int(radius))
            radius = self.opt.hm_gauss if self.opt.mse_loss else radius
            ct = np.array(
                [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
                dtype=np.float32)
            ct_int = ct.astype(np.int32)
            draw_gaussian(hm[cls_id], ct_int, radius)
            wh[k] = 1. * w, 1. * h
            ind[k] = ct_int[1] * output_w + ct_int[0]
            # Fractional part lost when the centre is snapped to the grid.
            reg[k] = ct - ct_int
            reg_mask[k] = 1
            cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]
            cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1
            if self.opt.dense_wh:
                draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)
            gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                           ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])

    # At this point the boxes have been rendered into the heatmap `hm`, and
    # `wh`, `ind`, `reg` and `reg_mask` hold the per-object targets.
    # Debug aids: uncomment to dump the network input and the heatmap.
    # cv2.imwrite('input-1.jpg', inp.transpose(1, 2, 0))
    # cv2.imwrite('hm.jpg', hm.transpose(1, 2, 0)*254)
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh}
    if self.opt.dense_wh:
        hm_a = hm.max(axis=0, keepdims=True)
        dense_wh_mask = np.concatenate([hm_a, hm_a], axis=0)
        ret.update({'dense_wh': dense_wh, 'dense_wh_mask': dense_wh_mask})
        del ret['wh']
    elif self.opt.cat_spec_wh:
        ret.update({'cat_spec_wh': cat_spec_wh, 'cat_spec_mask': cat_spec_mask})
        del ret['wh']
    if self.opt.reg_offset:
        ret.update({'reg': reg})
    if self.opt.debug > 0 or self.split != 'train':
        # Keep a fixed column count even when no object survived.
        gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
            np.zeros((1, 6), dtype=np.float32)
        meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
        ret['meta'] = meta
    return ret
图片(预处理后得到的 512×512):
处理后的热图,即代码中计算出的hm(128×128,因为进行了4倍下采样,得到的是512/4=128大小的热图):
二:结果处理
代码位置:CenterNet/src/lib/trains/ctdet.py
def debug(self, batch, output, iter_id):
    """Draw predicted and ground-truth detections for the first sample.

    Args:
        batch: Ground-truth batch (the labels); reads ``batch['input']``,
            ``batch['hm']`` and ``batch['meta']['gt_det']``.
        output: Model output dict; reads ``'hm'``, ``'wh'`` and, when
            ``opt.reg_offset`` is set, ``'reg'``.
        iter_id: Iteration id, used as the file-name prefix when the
            images are saved (``opt.debug == 4``).
    """
    opt = self.opt
    reg = output['reg'] if opt.reg_offset else None
    # Decode the raw network output into detections.
    dets = ctdet_decode(
        output['hm'], output['wh'], reg=reg,
        cat_spec_wh=opt.cat_spec_wh, K=opt.K)
    dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
    # Boxes are on the output grid; scale them back to input resolution.
    dets[:, :, :4] *= opt.down_ratio
    # Copy before scaling: Tensor.numpy() shares memory with the tensor, so
    # an in-place multiply would silently rescale batch['meta']['gt_det'].
    dets_gt = batch['meta']['gt_det'].numpy().reshape(
        1, -1, dets.shape[2]).copy()
    dets_gt[:, :, :4] *= opt.down_ratio
    # Only the first sample of the batch is visualised.
    for i in range(1):
        debugger = Debugger(
            dataset=opt.dataset, ipynb=(opt.debug == 3),
            theme=opt.debugger_theme)
        # Undo the input normalisation to recover a displayable uint8 image.
        img = batch['input'][i].detach().cpu().numpy().transpose(1, 2, 0)
        img = np.clip(
            ((img * opt.std + opt.mean) * 255.), 0, 255).astype(np.uint8)
        pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
        gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, 'pred_hm')
        debugger.add_blend_img(img, gt, 'gt_hm')

        # Draw the predicted boxes whose column-4 value exceeds the threshold.
        debugger.add_img(img, img_id='out_pred')
        for det in dets[i]:
            if det[4] > opt.center_thresh:
                debugger.add_coco_bbox(det[:4], det[-1],
                                       det[4], img_id='out_pred')

        # Draw the ground-truth (label) boxes the same way.
        debugger.add_img(img, img_id='out_gt')
        for det in dets_gt[i]:
            if det[4] > opt.center_thresh:
                debugger.add_coco_bbox(det[:4], det[-1],
                                       det[4], img_id='out_gt')

        if opt.debug == 4:
            # Save every rendered image to disk.
            debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
        else:
            debugger.show_all_imgs(pause=True)
保存的图片结果:
标签图:
输出图:
centernet训练不充分的时候得分很低?