HRNet model inference enters through this forward interface:
@auto_fp16(apply_to=('img', ))
def forward(self,
            img,
            target=None,
            target_weight=None,
            img_metas=None,
            return_loss=True,
            return_heatmap=False,
            **kwargs):
    """Calls either forward_train or forward_test depending on whether
    return_loss=True. Note this setting will change the expected inputs.
    When `return_loss=True`, img and img_meta are single-nested (i.e.
    Tensor and List[dict]), and when `return_loss=False`, img and img_meta
    should be double nested (i.e. List[Tensor], List[List[dict]]), with
    the outer list indicating test time augmentations.

    Note:
        - batch_size: N
        - num_keypoints: K
        - num_img_channel: C (Default: 3)
        - img height: imgH
        - img width: imgW
        - heatmaps height: H
        - heatmaps width: W

    Args:
        img (torch.Tensor[NxCximgHximgW]): Input images.
        target (torch.Tensor[NxKxHxW]): Target heatmaps.
        target_weight (torch.Tensor[NxKx1]): Weights across
            different joint types.
        img_metas (list(dict)): Information about data augmentation
            By default this includes:

            - "image_file": path to the image file
            - "center": center of the bbox
            - "scale": scale of the bbox
            - "rotation": rotation of the bbox
            - "bbox_score": score of bbox
        return_loss (bool): Option to return loss. `return_loss=True`
            for training, `return_loss=False` for validation & test.
        return_heatmap (bool): Option to return heatmap.

    Returns:
        dict|tuple: if `return_loss` is true, then return losses. \
            Otherwise, return predicted poses, boxes, image paths \
            and heatmaps.
    """
    if return_loss:  # False at inference, so we fall through to forward_test
        return self.forward_train(img, target, target_weight, img_metas,
                                  **kwargs)
    return self.forward_test(
        img, img_metas, return_heatmap=return_heatmap, **kwargs)
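For reference, a minimal usage sketch of this interface at inference time. The tensor shape and meta-dict values below are hypothetical placeholders, and `model` is assumed to be an already-built top-down detector with weights loaded:

import torch

# Hypothetical example inputs: one 256x192 crop and its meta dict.
img = torch.randn(1, 3, 256, 192)            # NxCximgHximgW
img_metas = [{
    'image_file': 'demo.jpg',
    'center': [96.0, 128.0],                  # bbox center in the original image
    'scale': [1.0, 1.0],                      # bbox scale (units of 200 px)
    'rotation': 0,
    'bbox_score': 1.0,
    'flip_pairs': [[1, 2], [3, 4], [5, 6]],   # left/right keypoint index pairs
}]

with torch.no_grad():
    result = model(img, img_metas=img_metas,
                   return_loss=False, return_heatmap=True)
# result['preds']          -> (1, K, 3) decoded keypoints with scores
# result['output_heatmap'] -> (1, K, H, W) averaged heatmaps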
We then enter forward_test:
def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
    """Defines the computation performed at every call when testing."""
    assert img.size(0) == len(img_metas)
    batch_size, _, img_height, img_width = img.shape  # torch.Size([1, 3, 256, 192])
    if batch_size > 1:  # False: batch size is 1 here
        assert 'bbox_id' in img_metas[0]

    result = {}

    features = self.backbone(img)  # torch.Size([1, 32, 64, 48])
    if self.with_neck:  # False
        features = self.neck(features)
    if self.with_keypoint:  # True; heatmap shape (1, 17, 64, 48)
        output_heatmap = self.keypoint_head.inference_model(
            features, flip_pairs=None)

    if self.test_cfg.get('flip_test', True):  # True
        img_flipped = img.flip(3)  # flip along dim 3 (the width axis): a horizontal mirror
        features_flipped = self.backbone(img_flipped)
        if self.with_neck:
            features_flipped = self.neck(features_flipped)
        if self.with_keypoint:  # True
            output_flipped_heatmap = self.keypoint_head.inference_model(
                features_flipped, img_metas[0]['flip_pairs'])
            output_heatmap = (output_heatmap + output_flipped_heatmap)  # sum the flipped-back heatmap and the original heatmap
            if self.test_cfg.get('regression_flip_shift', False):  # False
                output_heatmap[..., 0] -= 1.0 / img_width
            output_heatmap = output_heatmap / 2

    if self.with_keypoint:  # True
        keypoint_result = self.keypoint_head.decode(
            img_metas, output_heatmap, img_size=[img_width, img_height])
        result.update(keypoint_result)

        if not return_heatmap:
            output_heatmap = None

        result['output_heatmap'] = output_heatmap

    return result
The exported ONNX model is missing this part, so we add it back ourselves:
# the earlier ONNX export presumably only covered the model up to this point
if self.test_cfg.get('flip_test', True):  # True
    img_flipped = img.flip(3)  # flip along dim 3 (the width axis): a horizontal mirror
    features_flipped = self.backbone(img_flipped)
    if self.with_neck:
        features_flipped = self.neck(features_flipped)
    if self.with_keypoint:  # True
        output_flipped_heatmap = self.keypoint_head.inference_model(
            features_flipped, img_metas[0]['flip_pairs'])
        output_heatmap = (output_heatmap + output_flipped_heatmap)  # sum the flipped-back heatmap and the original
        if self.test_cfg.get('regression_flip_shift', False):  # False
            output_heatmap[..., 0] -= 1.0 / img_width
        output_heatmap = output_heatmap / 2
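To reproduce this flip-test branch around a plain ONNX export that only outputs raw heatmaps, something like the sketch below works. The model path and flip_pairs are placeholders, and the channel swap mimics what inference_model does when flip_pairs is given; treat it as an assumption, not the library's exact post-processing:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('hrnet.onnx')          # placeholder model path
input_name = sess.get_inputs()[0].name

def infer_with_flip_test(img, flip_pairs):
    """img: float32 array of shape (1, 3, imgH, imgW)."""
    heatmap = sess.run(None, {input_name: img})[0]
    # horizontally flip the input (width axis) and run the model again
    heatmap_flip = sess.run(None, {input_name: img[..., ::-1].copy()})[0]
    # map the flipped prediction back: mirror the width axis,
    # then swap the left/right keypoint channels
    heatmap_flip = heatmap_flip[..., ::-1].copy()
    for left, right in flip_pairs:
        heatmap_flip[:, [left, right]] = heatmap_flip[:, [right, left]]
    return (heatmap + heatmap_flip) / 2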
Structure of self.keypoint_head:
TopdownHeatmapSimpleHead(
  (loss): JointsMSELoss(
    (criterion): MSELoss()
  )
  (deconv_layers): Identity()
  (final_layer): Conv2d(32, 17, kernel_size=(1, 1), stride=(1, 1))
)
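So at inference the head reduces to a single 1x1 convolution: deconv_layers is Identity, and final_layer maps the 32 backbone channels to 17 keypoint heatmaps at the same 64x48 resolution. A quick sketch of the equivalent computation, with shapes taken from the annotations above:

import torch
import torch.nn as nn

final_layer = nn.Conv2d(32, 17, kernel_size=(1, 1), stride=(1, 1))
features = torch.randn(1, 32, 64, 48)   # backbone output
heatmaps = final_layer(features)        # -> torch.Size([1, 17, 64, 48])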
The code that generates the keypoints:
if self.with_keypoint:  # True
    keypoint_result = self.keypoint_head.decode(
        img_metas, output_heatmap, img_size=[img_width, img_height])
    result.update(keypoint_result)
This decode is identical to the version we ported over:
def decode(self, img_metas, output, **kwargs):
    """Decode keypoints from heatmaps.

    Args:
        img_metas (list(dict)): Information about data augmentation
            By default this includes:

            - "image_file": path to the image file
            - "center": center of the bbox
            - "scale": scale of the bbox
            - "rotation": rotation of the bbox
            - "bbox_score": score of bbox
        output (np.ndarray[N, K, H, W]): model predicted heatmaps.
    """
    batch_size = len(img_metas)

    if 'bbox_id' in img_metas[0]:
        bbox_ids = []
    else:
        bbox_ids = None

    c = np.zeros((batch_size, 2), dtype=np.float32)
    s = np.zeros((batch_size, 2), dtype=np.float32)
    image_paths = []
    score = np.ones(batch_size)
    for i in range(batch_size):
        c[i, :] = img_metas[i]['center']
        s[i, :] = img_metas[i]['scale']
        image_paths.append(img_metas[i]['image_file'])

        if 'bbox_score' in img_metas[i]:
            score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
        if bbox_ids is not None:
            bbox_ids.append(img_metas[i]['bbox_id'])

    preds, maxvals = keypoints_from_heatmaps(
        output,
        c,
        s,
        unbiased=self.test_cfg.get('unbiased_decoding', False),
        post_process=self.test_cfg.get('post_process', 'default'),
        kernel=self.test_cfg.get('modulate_kernel', 11),
        valid_radius_factor=self.test_cfg.get('valid_radius_factor',
                                              0.0546875),
        use_udp=self.test_cfg.get('use_udp', False),
        target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))

    all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
    all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
    all_preds[:, :, 0:2] = preds[:, :, 0:2]
    all_preds[:, :, 2:3] = maxvals
    all_boxes[:, 0:2] = c[:, 0:2]
    all_boxes[:, 2:4] = s[:, 0:2]
    all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
    all_boxes[:, 5] = score

    result = {}
    result['preds'] = all_preds
    result['boxes'] = all_boxes
    result['image_paths'] = image_paths
    result['bbox_ids'] = bbox_ids

    return result
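keypoints_from_heatmaps itself is not shown above; roughly, it takes the per-channel argmax of each heatmap and maps those coordinates back to the original image via the bbox center and scale, with scale measured in units of 200 pixels as the `s * 200.0` line suggests. A simplified numpy sketch of that idea, without the sub-pixel post-processing the real function applies:

import numpy as np

def decode_heatmaps_sketch(heatmaps, center, scale):
    """heatmaps: (N, K, H, W); center, scale: (N, 2). Returns (preds, maxvals)."""
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, -1)
    idx = np.argmax(flat, axis=2)                     # (N, K) flat peak indices
    maxvals = np.max(flat, axis=2)[..., None]         # (N, K, 1) peak scores
    preds = np.stack([idx % w, idx // w], -1).astype(np.float32)  # heatmap x, y
    box = scale * 200.0                               # bbox size in pixels
    # heatmap coordinates -> original-image coordinates
    preds[..., 0] = preds[..., 0] * box[:, None, 0] / w + center[:, None, 0] - box[:, None, 0] * 0.5
    preds[..., 1] = preds[..., 1] * box[:, None, 1] / h + center[:, None, 1] - box[:, None, 1] * 0.5
    return preds, maxvals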