Model structure
backbone dla34
DLA (Deep Layer Aggregation)
From the DLA paper: "We introduce two structures for deep layer aggregation (DLA): iterative deep aggregation (IDA) and hierarchical deep aggregation (HDA). Hierarchical deep aggregation merges blocks and stages in a tree to preserve and combine feature channels. IDA focuses on fusing resolutions and scales while HDA focuses on merging features from all modules and channels."
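Informally, IDA folds an aggregation node over the feature hierarchy, refining the fused feature at every step. A minimal sketch of that idea (the aggregate argument stands in for an aggregation node such as the Root module shown later):

def ida(aggregate, feats):
    # iteratively merge features from shallowest/finest to deepest/coarsest
    x = feats[0]
    for f in feats[1:]:
        x = aggregate(x, f)
    return x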
HDA
The basic convolution block
import torch
import torch.nn as nn

BN_MOMENTUM = 0.1  # batch-norm momentum used throughout (0.1 in the repo)


class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=dilation,
                               bias=False, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=dilation,
                               bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out
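A quick shape check of the block (hypothetical sizes): with stride 1 and equal channel counts, the residual add works directly and the spatial size is preserved.

block = BasicBlock(64, 64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])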
Modules that form the tree structure
class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1,
            stride=1, bias=False, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)
        return x


class Tree(nn.Module):
    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
            )

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x
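A quick sanity check of the recursion (hypothetical sizes): a two-level tree that downsamples by 2 and widens the channels. With level_root=True, the downsampled input is also fed into the final Root, so its concatenation width is root_dim = 2 * 128 + 64 + 128 = 448.

tree = Tree(2, BasicBlock, 64, 128, stride=2, level_root=True)
x = torch.randn(1, 64, 56, 56)
print(tree(x).shape)  # torch.Size([1, 128, 28, 28])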
[Figure: HDA structure]
[Figure: IDA structure]
IDAUp
IDAUp is the operation the IDA module applies repeatedly, each time building on the result of the previous step. As shown in the figure above, every level of the pyramid is produced by an IDAUp call, and chaining these calls yields the pyramid structure.
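The class below relies on two helpers not shown in this section: DeformConv (in the CenterNet repo, a 3x3 deformable convolution followed by BN and ReLU) and fill_up_weights, which initializes the grouped transposed convolution as a fixed bilinear upsampler. A sketch of the latter, matching the repo's implementation:

import math

def fill_up_weights(up):
    # bilinear-upsampling initialization for a grouped ConvTranspose2d
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for ch in range(1, w.size(0)):
        w[ch, 0, :, :] = w[0, 0, :, :]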
class IDAUp(nn.Module):
    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            proj = DeformConv(c, o)   # project to the common channel width o
            node = DeformConv(o, o)   # aggregation node applied after the merge
            up = nn.ConvTranspose2d(o, o, f * 2, stride=f,
                                    padding=f // 2, output_padding=0,
                                    groups=o, bias=False)
            fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
            setattr(self, 'node_' + str(i), node)

    def forward(self, layers, startp, endp):
        # modifies `layers` in place: each level is projected, upsampled,
        # and merged with the previous (finer) level
        for i in range(startp + 1, endp):
            upsample = getattr(self, 'up_' + str(i - startp))
            project = getattr(self, 'proj_' + str(i - startp))
            layers[i] = upsample(project(layers[i]))
            node = getattr(self, 'node_' + str(i - startp))
            layers[i] = node(layers[i] + layers[i - 1])
import numpy as np


class DLAUp(nn.Module):
    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        self.startp = startp
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(self, 'ida_{}'.format(i),
                    IDAUp(channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        out = [layers[-1]]  # start from the deepest (stride-32) feature
        for i in range(len(layers) - self.startp - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(layers, len(layers) - i - 2, len(layers))
            out.insert(0, layers[-1])
        return out
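How the two classes compose (a sketch with hypothetical DLA-34-like channels and strides, assuming DeformConv is importable): DLAUp takes the list of pyramid features and, working from coarse to fine, upsamples and merges deeper levels into shallower ones, collecting a refined pyramid.

channels = [64, 128, 256, 512]   # pyramid channels at strides 4, 8, 16, 32
scales = [1, 2, 4, 8]
dla_up = DLAUp(startp=0, channels=channels, scales=scales)
feats = [torch.randn(1, c, 128 // s, 128 // s) for c, s in zip(channels, scales)]
out = dla_up(feats)
print([tuple(o.shape) for o in out])
# [(1, 64, 128, 128), (1, 128, 64, 64), (1, 256, 32, 32), (1, 512, 16, 16)]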
Overall DLA structure
class DLASeg(nn.Module):
    def __init__(self, base_name, pretrained, down_ratio, final_kernel,
                 last_level, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        self.first_level = int(np.log2(down_ratio))  # down_ratio=4 -> first_level=2
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained)
        channels = self.base.channels
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)
        if out_channel == 0:
            out_channel = channels[self.first_level]
        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level],
                            [2 ** i for i in range(self.last_level - self.first_level)])

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x)
        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        self.ida_up(y, 0, len(y))
        x = y[-1]
        return x
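Putting it together (a sketch, assuming the repo's dla34 constructor is in scope): with a 512x512 input and down_ratio=4, the network returns a single stride-4 feature map that all heads share.

net = DLASeg('dla34', pretrained=False, down_ratio=4,
             final_kernel=1, last_level=5)
x = torch.randn(1, 3, 512, 512)
feat = net(x)
print(feat.shape)  # torch.Size([1, 64, 128, 128])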
head
class KeypointHead(nn.Module):
    def __init__(self, intermediate_channel, head_conv):
        super(KeypointHead, self).__init__()
        # center heatmap: 1 class (person)
        self.hm = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 1, kernel_size=1, stride=1, padding=0))
        # box width and height
        self.wh = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 2, kernel_size=1, stride=1, padding=0))
        # keypoint offsets from the object center: 17 joints x 2 = 34
        self.hps = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 34, kernel_size=1, stride=1, padding=0))
        # center sub-pixel offset
        self.reg = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 2, kernel_size=1, stride=1, padding=0))
        # per-joint keypoint heatmaps: 17
        self.hm_hp = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 17, kernel_size=1, stride=1, padding=0))
        # keypoint sub-pixel offset
        self.hp_offset = nn.Sequential(
            nn.Conv2d(intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 2, kernel_size=1, stride=1, padding=0))
        self.init_weights()  # defined elsewhere in the repo

    def forward(self, x):
        return [self.hm(x), self.wh(x), self.hps(x), self.reg(x),
                self.hm_hp(x), self.hp_offset(x)]
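Fed the 64-channel, 128x128 feature map from DLASeg, the head produces six tensors (a shape sketch, assuming the repo's init_weights method is present):

head = KeypointHead(intermediate_channel=64, head_conv=256)
out = head(torch.randn(1, 64, 128, 128))
print([tuple(o.shape) for o in out])
# [(1, 1, 128, 128), (1, 2, 128, 128), (1, 34, 128, 128),
#  (1, 2, 128, 128), (1, 17, 128, 128), (1, 2, 128, 128)]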
train
dataloader
The model input size (INPUT_RES) is 512x512 and the output feature map (OUTPUT_RES) is 128x128.
Two affine transforms are computed, one mapping the original image to the input size and one mapping it to the output size, yielding two transformation matrices.
The ground-truth boxes are warped with the output-size transform; the center of each warped box is the positive sample point, and every other location is a negative.
The center location also indexes the outputs of the other heads (e.g. wh) when computing their losses.
Both the object centers and the keypoints are supervised as heatmaps, rendered by splatting a Gaussian around each ground-truth point.
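The Gaussian rendering mentioned above is done by draw_umich_gaussian (or draw_msra_gaussian when MSE loss is used). A sketch of the former, matching the CenterNet utilities: the Gaussian patch is clipped at the heatmap border and merged with an element-wise max, so overlapping objects keep the stronger response.

import numpy as np

def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

def draw_umich_gaussian(heatmap, center, radius, k=1):
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]
    # clip the patch so it stays inside the heatmap
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom,
                               radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        # keep the stronger response where objects overlap
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap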
import os
import math
import cv2
# get_affine_transform, affine_transform, color_aug, gaussian_radius,
# draw_msra_gaussian, draw_umich_gaussian and draw_dense_reg come from the
# repo's image utilities

def __getitem__(self, index):
    # get img_id through index
    img_id = self.images[index]
    # get the image file name by img_id
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    # build the image path from the dataset directory and file name
    img_path = os.path.join(self.img_dir, file_name)
    # get all annotation ids for this image
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    # load the annotations
    anns = self.coco.loadAnns(ids=ann_ids)
    # keep annotations whose category_id is valid and that are not crowd-labeled
    anns = list(filter(
        lambda x: x['category_id'] in self._valid_ids and x['iscrowd'] != 1,
        anns))
    # limit the number of objects per image
    num_objs = min(len(anns), self.max_objs)
    # read the image
    img = cv2.imread(img_path)
    height, width = img.shape[0], img.shape[1]
    # center of the image, (x, y)
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    # the scale is defined as the longer edge
    s = max(img.shape[0], img.shape[1]) * 1.0
    # rotation angle in degrees
    rot = 0
    flipped = False
    if self.split == 'train':
        if self.cfg.DATASET.RANDOM_CROP:  # true by default
            s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
            w_border = self._get_border(128, img.shape[1])
            h_border = self._get_border(128, img.shape[0])
            c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
            c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
        else:
            # randomly jitter center and scale
            sf = self.cfg.DATASET.SCALE
            cf = self.cfg.DATASET.SHIFT
            c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
        if np.random.random() < self.cfg.DATASET.AUG_ROT:
            rf = self.cfg.DATASET.ROTATE
            rot = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)
        if np.random.random() < self.cfg.DATASET.FLIP:
            flipped = True
            img = img[:, ::-1, :]
            c[0] = width - c[0] - 1
    # affine matrix mapping the original image to the input resolution
    trans_input = get_affine_transform(
        c, s, rot, [self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES])
    # warp the original image to the input size
    inp = cv2.warpAffine(img, trans_input,
                         (self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES),
                         flags=cv2.INTER_LINEAR)
    # scale to [0, 1]
    inp = (inp.astype(np.float32) / 255.)
    if self.split == 'train' and not self.cfg.DATASET.NO_COLOR_AUG:
        color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    # normalize
    inp = (inp - np.array(self.cfg.DATASET.MEAN).astype(np.float32)) / \
          np.array(self.cfg.DATASET.STD).astype(np.float32)
    # HWC -> CHW
    inp = inp.transpose(2, 0, 1)
    output_res = self.cfg.MODEL.OUTPUT_RES
    num_joints = self.num_joints
    # affine matrix mapping the original image to the output size, with rotation
    trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res])
    # same mapping without rotation (note: it maps the original image to the
    # output size, not the input size to the output size)
    trans_output = get_affine_transform(c, s, 0, [output_res, output_res])
    # mapping for the segmentation mask to the output size
    trans_seg_output = get_affine_transform(c, s, 0, [output_res, output_res])
    # center heatmap target
    hm = np.zeros((self.num_classes, output_res, output_res), dtype=np.float32)
    # keypoint heatmap target
    hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
    dense_kps = np.zeros((num_joints, 2, output_res, output_res),
                         dtype=np.float32)
    dense_kps_mask = np.zeros((num_joints, output_res, output_res),
                              dtype=np.float32)
    # box sizes of all objects
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    # keypoint offsets relative to the integer center, in output coordinates
    kps = np.zeros((self.max_objs, num_joints * 2), dtype=np.float32)
    # offset between the float center and the integer center, in output coordinates
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    # flattened index of each object center on the output grid
    ind = np.zeros((self.max_objs), dtype=np.int64)
    # mask of real objects (up to max_objs per image)
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    # mask of labeled, visible keypoints
    kps_mask = np.zeros((self.max_objs, self.num_joints * 2), dtype=np.uint8)
    hp_offset = np.zeros((self.max_objs * num_joints, 2), dtype=np.float32)
    # flattened index of each keypoint on the output grid
    hp_ind = np.zeros((self.max_objs * num_joints), dtype=np.int64)
    # same idea as kps_mask, but per keypoint
    hp_mask = np.zeros((self.max_objs * num_joints), dtype=np.int64)
    # the keypoint Gaussians are drawn first, then the center Gaussian
    draw_gaussian = draw_msra_gaussian if self.cfg.LOSS.MSE_LOSS else \
        draw_umich_gaussian
    gt_det = []
    for k in range(num_objs):
        ann = anns[k]
        bbox = self._coco_box_to_bbox(ann['bbox'])
        cls_id = int(ann['category_id']) - 1
        pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3)
        segment = self.coco.annToMask(ann)
        if flipped:
            bbox[[0, 2]] = width - bbox[[2, 0]] - 1
            pts[:, 0] = width - pts[:, 0] - 1
            for e in self.flip_idx:
                pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()
            segment = segment[:, ::-1]
        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox = np.clip(bbox, 0, output_res - 1)
        segment = cv2.warpAffine(segment, trans_seg_output,
                                 (output_res, output_res),
                                 flags=cv2.INTER_LINEAR)
        segment = segment.astype(np.float32)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        if (h > 0 and w > 0) or (rot != 0):
            # Gaussian radius derived from the box size
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            # the latter branch is used here
            radius = self.cfg.hm_gauss if self.cfg.LOSS.MSE_LOSS else max(0, int(radius))
            # float object center in output coordinates
            ct = np.array(
                [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
            # integer center location
            ct_int = ct.astype(np.int32)
            # box size target for the k-th object
            wh[k] = 1. * w, 1. * h
            # flattened index of the k-th object center
            ind[k] = ct_int[1] * output_res + ct_int[0]
            # float-vs-integer center difference, to recover the discretization error
            reg[k] = ct - ct_int
            reg_mask[k] = 1
            # keypoints
            num_kpts = pts[:, 2].sum()
            if num_kpts == 0:
                # no labeled keypoints: 0.9999 makes the focal-loss negative
                # weight (1 - gt)^4 nearly zero, i.e. the location is ignored
                hm[cls_id, ct_int[1], ct_int[0]] = 0.9999
                reg_mask[k] = 0
            hp_radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            hp_radius = self.cfg.hm_gauss \
                if self.cfg.LOSS.MSE_LOSS else max(0, int(hp_radius))
            for j in range(num_joints):
                if pts[j, 2] > 0:
                    pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
                    if pts[j, 0] >= 0 and pts[j, 0] < output_res and \
                            pts[j, 1] >= 0 and pts[j, 1] < output_res:
                        # keypoint offset from the integer center
                        kps[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int
                        kps_mask[k, j * 2: j * 2 + 2] = 1
                        pt_int = pts[j, :2].astype(np.int32)
                        # float-vs-integer keypoint difference
                        hp_offset[k * num_joints + j] = pts[j, :2] - pt_int
                        hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0]
                        hp_mask[k * num_joints + j] = 1
                        if self.cfg.LOSS.DENSE_HP:
                            # must be drawn before the center Gaussian
                            draw_dense_reg(dense_kps[j], hm[cls_id], ct_int,
                                           pts[j, :2] - ct_int, radius, is_offset=True)
                            draw_gaussian(dense_kps_mask[j], ct_int, radius)
                        draw_gaussian(hm_hp[j], pt_int, hp_radius)
            draw_gaussian(hm[cls_id], ct_int, radius)
            gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                           ct[0] + w / 2, ct[1] + h / 2, 1] +
                          pts[:, :2].reshape(num_joints * 2).tolist() + [cls_id])
    if rot != 0:
        # rotation makes the axis-aligned boxes invalid, so ignore all targets
        hm = hm * 0 + 0.9999
        reg_mask *= 0
        kps_mask *= 0
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh,
           'hps': kps, 'hps_mask': kps_mask}
    if self.cfg.LOSS.DENSE_HP:
        dense_kps = dense_kps.reshape(num_joints * 2, output_res, output_res)
        dense_kps_mask = dense_kps_mask.reshape(
            num_joints, 1, output_res, output_res)
        dense_kps_mask = np.concatenate([dense_kps_mask, dense_kps_mask], axis=1)
        dense_kps_mask = dense_kps_mask.reshape(
            num_joints * 2, output_res, output_res)
        ret.update({'dense_hps': dense_kps, 'dense_hps_mask': dense_kps_mask})
        del ret['hps'], ret['hps_mask']
    if self.cfg.LOSS.REG_OFFSET:
        ret.update({'reg': reg})
    if self.cfg.LOSS.HM_HP:
        ret.update({'hm_hp': hm_hp})
    if self.cfg.LOSS.REG_HP_OFFSET:
        ret.update({'hp_offset': hp_offset, 'hp_ind': hp_ind, 'hp_mask': hp_mask})
    if self.cfg.DEBUG > 0 or not self.split == 'train':
        gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
            np.zeros((1, 40), dtype=np.float32)
        meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
        ret['meta'] = meta
    return ret
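A typical way to consume the dataset (a sketch; the dataset class name COCOHP and its constructor arguments are assumptions):

from torch.utils.data import DataLoader

train_dataset = COCOHP(cfg, split='train')  # hypothetical wrapper around the __getitem__ above
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)
for batch in train_loader:
    print(batch['input'].shape)  # torch.Size([16, 3, 512, 512])
    print(batch['hm'].shape)     # torch.Size([16, 1, 128, 128])
    break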
loss
The loss has six parts: the center-point heatmap, the center-point discretization offset, the object width and height, the keypoint heatmap, the keypoint discretization offset, and the keypoint displacement from the center.
class MultiPoseLoss(torch.nn.Module):
    def __init__(self, cfg, local_rank):
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()              # hm
        self.crit_hm_hp = FocalLoss()        # hm_hp
        self.crit_kp = RegWeightedL1Loss()   # keypoint offsets
        self.crit_reg = RegL1Loss()          # wh, reg, hp_offset
        self.cfg = cfg
        self.local_rank = local_rank

    def forward(self, outputs, batch):
        cfg = self.cfg
        hm_loss, wh_loss = 0, 0
        hp_loss, off_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0, 0
        hm, wh, hps, reg, hm_hp, hp_offset = outputs
        for s in range(cfg.MODEL.NUM_STACKS):
            hm = _sigmoid(hm)  # (16, 1, 128, 128)
            if cfg.LOSS.HM_HP and not cfg.LOSS.MSE_LOSS:
                hm_hp = _sigmoid(hm_hp)  # (16, 17, 128, 128)
            # hm loss is a focal loss
            hm_loss += self.crit(hm, batch['hm']) / cfg.MODEL.NUM_STACKS
            hp_loss += self.crit_kp(hps, batch['hps_mask'],  # hps: (16, 34, 128, 128)
                                    batch['ind'], batch['hps']) / cfg.MODEL.NUM_STACKS
            if cfg.LOSS.WH_WEIGHT > 0:
                # gather wh at the ground-truth center indices and regress
                wh_loss += self.crit_reg(wh, batch['reg_mask'],
                                         batch['ind'], batch['wh']) / cfg.MODEL.NUM_STACKS
            if cfg.LOSS.REG_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:  # true
                off_loss += self.crit_reg(reg, batch['reg_mask'],
                                          batch['ind'], batch['reg']) / cfg.MODEL.NUM_STACKS
            if cfg.LOSS.REG_HP_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:  # true
                # use the keypoint indices to regress the keypoint discretization error
                hp_offset_loss += self.crit_reg(
                    hp_offset, batch['hp_mask'],
                    batch['hp_ind'], batch['hp_offset']) / cfg.MODEL.NUM_STACKS
            if cfg.LOSS.HM_HP and cfg.LOSS.HM_HP_WEIGHT > 0:
                hm_hp_loss += self.crit_hm_hp(
                    hm_hp, batch['hm_hp']) / cfg.MODEL.NUM_STACKS
        loss = cfg.LOSS.HM_WEIGHT * hm_loss + cfg.LOSS.WH_WEIGHT * wh_loss + \
               cfg.LOSS.OFF_WEIGHT * off_loss + cfg.LOSS.HP_WEIGHT * hp_loss + \
               cfg.LOSS.HM_HP_WEIGHT * hm_hp_loss + cfg.LOSS.OFF_WEIGHT * hp_offset_loss
        loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
                      'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss}
        return loss, loss_stats
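The regression criteria gather predictions at the ground-truth indices and apply a masked L1. A sketch matching the CenterNet losses (_transpose_and_gather_feat is shown in the eval section below):

import torch.nn.functional as F

class RegL1Loss(nn.Module):
    def forward(self, output, mask, ind, target):
        # pick the predictions at the ground-truth center/keypoint indices
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # masked L1, normalized by the number of supervised entries
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        return loss / (mask.sum() + 1e-4)

class RegWeightedL1Loss(nn.Module):
    def forward(self, output, mask, ind, target):
        # the mask already has per-coordinate shape (e.g. hps_mask: B x max_objs x 34)
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.float()
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        return loss / (mask.sum() + 1e-4)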
heatmap loss computation
def _neg_loss(pred, gt):
    '''Modified focal loss, exactly the same as in CornerNet.
    Runs faster and costs a little more memory.
    Arguments:
        pred (batch x c x h x w)
        gt (batch x c x h x w)
    '''
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    # negatives near a ground-truth peak are down-weighted by (1 - gt)^4
    neg_weights = torch.pow(1 - gt, 4)
    loss = 0
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss
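_neg_loss takes log(pred) and log(1 - pred), so the raw head output must first be squashed into (0, 1) and kept away from the endpoints. The _sigmoid used in MultiPoseLoss does exactly that (a sketch matching the repo):

def _sigmoid(x):
    # clamp away from 0 and 1 so the focal-loss logs stay finite
    return torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)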
eval
topk
First select the top-K scoring points within each category, then pool those candidates together and pick the overall top K.
def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    # top K values within each category
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)  # topk_inds: batch x cat x K
    topk_inds = topk_inds % (height * width)
    # recover the (y, x) locations from the flattened indices
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()
    # top K across all categories
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)  # topk_ind: batch x K
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
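_gather_feat and its channel-last companion are small indexing helpers (a sketch matching the CenterNet utilities): for each batch element, they select the feature rows at the given flattened spatial indices. The _topk_channel used in the decoder below is the analogous per-channel top-K that skips the cross-category step.

def _gather_feat(feat, ind, mask=None):
    # feat: b x (h*w) x dim, ind: b x K  ->  b x K x dim
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    # feat: b x c x h x w  ->  b x (h*w) x c, then gather at ind
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    return _gather_feat(feat, ind)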
post process
1. First find the object centers from hm, then look up the corresponding reg and wh to compute each bbox.
2. The keypoints are determined jointly by kps and hm_hp: kps knows which object each keypoint belongs to but is less precise, while hm_hp is precise but does not know which instance a keypoint belongs to. The decoder therefore snaps each regressed keypoint to the nearest confident heatmap peak that falls inside the box.
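The _nms used below is a max-pooling peak filter (a sketch matching the CenterNet utilities): a location survives only if it is the maximum in its 3x3 neighborhood.

def _nms(heat, kernel=3):
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep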
def whole_body_decode(
        heat, wh, kps, seg_feat=None, seg=None, reg=None, hm_hp=None,
        hp_offset=None, K=100):
    batch, cat, height, width = heat.size()
    num_joints = kps.shape[1] // 2
    # perform nms on the heatmap
    heat = _nms(heat)
    scores, inds, clses, ys, xs = _topk(heat, K=K)
    kps = _transpose_and_gather_feat(kps, inds)
    kps = kps.view(batch, K, num_joints * 2)
    # turn center-relative offsets into absolute keypoint coordinates
    kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
    kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 2)
    weight = _transpose_and_gather_feat(seg, inds)
    # you can branch on (weight.size(1) != seg_feat.size(1)) here to choose
    # between a 3x3 and a 1x1 seg conv; the 3x3 case:
    weight = weight.view([weight.size(1), -1, 3, 3])
    pred_seg = F.conv2d(seg_feat, weight, stride=1, padding=1)
    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2,
                        ys + wh[..., 1:2] / 2], dim=2)
    if hm_hp is not None:
        hm_hp = _nms(hm_hp)
        thresh = 0.1
        kps = kps.view(batch, K, num_joints, 2).permute(
            0, 2, 1, 3).contiguous()  # b x K x 34 -> b x J x K x 2
        # reg_kps duplicates (b, j, k, 1, 2) K times, which is different from
        # hm_kps below, which duplicates (b, j, 1, k, 2) K times
        reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
        # top scores for each of the 17 joints, with their indices, ys, xs
        hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K)  # b x J x K
        # refine the peak positions with hp_offset
        if hp_offset is not None:
            hp_offset = _transpose_and_gather_feat(
                hp_offset, hm_inds.view(batch, -1))
            hp_offset = hp_offset.view(batch, num_joints, K, 2)
            hm_xs = hm_xs + hp_offset[:, :, :, 0]
            hm_ys = hm_ys + hp_offset[:, :, :, 1]
        else:
            hm_xs = hm_xs + 0.5
            hm_ys = hm_ys + 0.5
        # build a mask of confident peaks
        mask = (hm_score > thresh).float()
        # keep hm_score, hm_ys, hm_xs where hm_score > thresh; push the rest
        # far away so they never win the distance match
        hm_score = (1 - mask) * -1 + mask * hm_score
        hm_ys = (1 - mask) * (-10000) + mask * hm_ys
        hm_xs = (1 - mask) * (-10000) + mask * hm_xs
        # hm_kps are the keypoint candidates taken from the joint heatmaps
        hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
            2).expand(batch, num_joints, K, K, 2)
        # distance between every regressed keypoint and every heatmap candidate
        dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)
        # snap each regressed keypoint to its nearest heatmap candidate
        min_dist, min_ind = dist.min(dim=3)  # b x J x K
        hm_score = hm_score.gather(2, min_ind).unsqueeze(-1)  # b x J x K x 1
        min_dist = min_dist.unsqueeze(-1)
        min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
            batch, num_joints, K, 1, 2)
        hm_kps = hm_kps.gather(3, min_ind)
        hm_kps = hm_kps.view(batch, num_joints, K, 2)
        l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        # reject matches that fall outside the box, score below the threshold,
        # or lie too far from the regressed position
        mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
               (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
               (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
        mask = (mask > 0).float().expand(batch, num_joints, K, 2)
        # fall back to the regressed keypoints where the match was rejected
        kps = (1 - mask) * hm_kps + mask * kps
        kps = kps.permute(0, 2, 1, 3).contiguous().view(
            batch, K, num_joints * 2)
    detections = torch.cat([bboxes, scores, kps,
                            torch.transpose(hm_score.squeeze(dim=3), 1, 2)], dim=2)
    return detections, pred_seg
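When hm_hp is provided, each row of detections packs 4 bbox coordinates, 1 center score, 34 keypoint coordinates, and the 17 matched joint scores (56 values per detection), all in output-resolution (128x128) coordinates; the usual post-processing then maps them back to the original image with the c and s stored in meta.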