# Per-image ROIAlign over the predicted class heatmaps.
#
# For each image in the batch this:
#   1. decodes detections from that image's heatmap and regression maps with
#      `ctdet_decode` (K=100);
#   2. for each processed detection, ROI-aligns the heatmap channel of its
#      predicted class (0, 1 or 2) inside its box to a fixed 128x128 grid;
#   3. stacks the per-class results channel-wise, using a zero map for any
#      class without a processed detection;
#   4. adds the image's own heatmap.
# The per-image results are then concatenated along dim 0 into
# `all_roi_align_feats`.
#
# Free names read from the enclosing scope (defined outside this chunk):
# `x`, `out_hmap_1`, `regs`, `ctdet_decode`, `torch`, `torchvision`.
# NOTE(review): the final `+= hmap1_s` only broadcasts if out_hmap_1 is
# (batch, 3, 128, 128) (or has a single channel) — TODO confirm.
# NOTE(review): roi_align uses the default spatial_scale=1.0, so the boxes
# are presumably in heatmap coordinates — verify against ctdet_decode.
roi_output_size = (128, 128)  # ROIAlign output grid; also the empty-class fill size
num_roi_classes = 3  # only class ids 0..2 are collected

batch_size = x.shape[0]
all_roi_align_feats = []
for i in range(batch_size):
    # Slicing with i:i + 1 keeps a batch dim of 1 (same as [i].unsqueeze(0)).
    hmap1_s = out_hmap_1[i:i + 1]
    regs1_s = regs[i:i + 1]
    dets = ctdet_decode(hmap1_s, regs1_s, K=100)

    # One bucket of ROI features per class channel.
    per_class_rois = [[] for _ in range(num_roi_classes)]
    # NOTE(review): len(dets) is the size of dets' FIRST dimension, but the
    # body indexes detections along the SECOND dimension (dets[:, index]).
    # If dets is (1, K, ...) — likely, since hmap1_s has batch size 1 — only
    # detection 0 is ever processed. This preserves the original behaviour.
    # The channel-wise torch.cat below also relies on it: every class bucket
    # must concatenate to the same dim-0 size, i.e. at most one ROI each.
    # Using all K detections would need a score threshold and a per-class
    # merge rule — TODO confirm the intended design.
    for index in range(len(dets)):
        # Last column of a detection row is read as its class id.
        cls_index = int(dets[0, index, -1].item())
        if not 0 <= cls_index < num_roi_classes:
            # Class ids outside 0..2 are never collected; skip the wasted
            # roi_align call instead of computing and discarding it.
            continue
        roi = torchvision.ops.roi_align(
            # Single class channel, kept 4-D: (1, 1, H, W).
            input=hmap1_s[:, cls_index:cls_index + 1],
            # One (L, 4) box tensor per image in `input` (here one image).
            boxes=[dets[:, index, :4]],
            output_size=roi_output_size,
        )
        per_class_rois[cls_index].append(roi)

    # Empty buckets become a zero map created on hmap1_s's device and dtype.
    # The original hard-coded `.cuda()` failed on CPU-only runs, could land on
    # a different GPU than the heatmap, and ignored half precision.
    class_feats = [
        torch.cat(rois, 0) if rois else hmap1_s.new_zeros((1, 1, *roi_output_size))
        for rois in per_class_rois
    ]
    roi_align_feats = torch.cat(class_feats, 1)
    roi_align_feats += hmap1_s
    all_roi_align_feats.append(roi_align_feats)
all_roi_align_feats = torch.cat(all_roi_align_feats, 0)
# PyTorch ROIAlign usage example.
# (Page footer: the latest recommended article was published 2024-05-16 21:51:17.)