pytorch使用 ROIalign 代码实例

batch_size = x.shape[0]
all_roi_align_feats = []
for i in range(batch_size):
    hmap1_s = out_hmap_1[i, :, :, :]
    hmap1_s = hmap1_s.unsqueeze(0)
    # print('hmap1_s: ', hmap1_s.shape)
    regs1_s = regs[i, :, :, :]
    regs1_s = regs1_s.unsqueeze(0)

    dets = ctdet_decode(hmap1_s, regs1_s, K=100)
    # print('dets: ', dets.shape)

    roi_align_feats_0 = []
    roi_align_feats_1 = []
    roi_align_feats_2 = []
    for index, box in enumerate(dets):
        box_list = []
        # print(dets[0, index, :4])
        box_list.append(dets[:, index, :4])
        cls_index = int(dets[:, index, -1].cpu().detach().numpy()[0])
        # print('cls_index: ', cls_index)
        roi = torchvision.ops.roi_align(input=hmap1_s[:, cls_index, :, :].unsqueeze(0), boxes=box_list,
                                        output_size=(128, 128))

        if cls_index == 0:
            roi_align_feats_0.append(roi)
            # print('roi_align_feats_0.append(roi)')
        elif cls_index == 1:
            roi_align_feats_1.append(roi)
            # print('roi_align_feats_1.append(roi)')
        elif cls_index == 2:
            roi_align_feats_2.append(roi)
            # print('roi_align_feats_2.append(roi)')

    if len(roi_align_feats_0):
        roi_align_feats_0 = torch.cat(roi_align_feats_0, 0)
    else:
        roi_align_feats_0 = torch.zeros((1, 1, 128, 128))

    if len(roi_align_feats_1):
        roi_align_feats_1 = torch.cat(roi_align_feats_1, 0)
    else:
        roi_align_feats_1 = torch.zeros((1, 1, 128, 128))

    if len(roi_align_feats_2):
        roi_align_feats_2 = torch.cat(roi_align_feats_2, 0)
    else:
        roi_align_feats_2 = torch.zeros((1, 1, 128, 128))

    # roi_align_feats_1 = torch.cat(roi_align_feats_1, 0)
    # roi_align_feats_2 = torch.cat(roi_align_feats_2, 0)
    roi_align_feats_0 = roi_align_feats_0.cuda()
    roi_align_feats_1 = roi_align_feats_1.cuda()
    roi_align_feats_2 = roi_align_feats_2.cuda()
    roi_align_feats = torch.cat((roi_align_feats_0, roi_align_feats_1, roi_align_feats_2), 1)

    roi_align_feats += hmap1_s
    # print('roi_align_feats:', roi_align_feats.shape)
    all_roi_align_feats.append(roi_align_feats)

all_roi_align_feats = torch.cat(all_roi_align_feats, 0)

​

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值