【无标题】

def TCA_inference(opt, max_proposals=100):
    """Run TCANet inference over a dataset split and write one CSV per video.

    Loads ``TCA_best.pth.tar`` from ``opt["checkpoint_path"]``, runs the model
    on every video of the ``opt["inference_dataset"]`` subset, and writes the
    outputs of the three refinement stages, together with the input scores,
    to ``<checkpoint_path>/TCA_results/<video_name>.csv``. The CSV contains
    the refined proposal boundaries (divided by the video duration) and the
    predicted IoUs.

    Args:
        opt: Option mapping. This function reads ``"checkpoint_path"``,
            ``"inference_dataset"`` and ``"num_workers"``, and passes the
            whole mapping on to ``TCANet`` and ``TCADataSet``.
        max_proposals: Number of proposals fed to the model per forward pass.
            Videos with more proposals are processed in consecutive chunks of
            this size. Defaults to 100, the value previously hard-coded as
            ``200 // 2``.

    Raises:
        ValueError: If ``max_proposals`` is not positive.
    """
    if max_proposals <= 0:
        raise ValueError("max_proposals must be positive, got {}".format(max_proposals))

    model = TCANet('inference', opt)
    model = torch.nn.DataParallel(model).cuda()
    checkpoint = torch.load(os.path.join(opt["checkpoint_path"], "TCA_best.pth.tar"))
    try:
        # First try a state dict matching the DataParallel wrapper.
        model.load_state_dict(checkpoint['state_dict'])
    except (KeyError, RuntimeError) as e:
        # Fall back to a bare-model state dict stored under 'model_state'.
        print("{}  continue?".format(e))
        model.module.load_state_dict(checkpoint['model_state'])
    model.eval()

    test_loader = torch.utils.data.DataLoader(TCADataSet(opt, subset=opt["inference_dataset"]),
                                              batch_size=1, shuffle=False,
                                              num_workers=opt["num_workers"], pin_memory=True, drop_last=False)

    result_dir = os.path.join(opt["checkpoint_path"], "TCA_results")
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(result_dir, exist_ok=True)

    # Per-stage model outputs collected for every chunk of proposals.
    pred_keys = ("proposals1", "iou1", "proposals2", "iou2", "proposals3", "iou3")
    col_name = ["xmin", "xmax", "xmin3", "xmax3", "xmin2", "xmax2", "score", "ori_score", "xmin_score",
                "xmax_score", "pred_iou", "preds_iou3", "preds_iou2"]

    with torch.no_grad():
        for n_iter, meta in tqdm(enumerate(test_loader), total=len(test_loader)):
            # batch_size=1 and shuffle=False, so batch n_iter corresponds to
            # entry n_iter of the dataset's video list.
            video_name = test_loader.dataset.video_list[n_iter]
            csv_path = os.path.join(result_dir, video_name + ".csv")

            features = meta["features"].cuda()
            proposals = meta["proposals"].cuda()
            video_second = meta["video_duration"].cuda()
            temporal_mask = meta['temporal_mask'].cuda()
            score = meta["score"].squeeze(0).numpy()
            video_duration = meta["video_duration"].item()

            # Proposals are sliced along dim 1 below; presumably the layout is
            # (batch, num_proposals, 2) -- TODO confirm against TCADataSet.
            num_proposals = proposals.size(1)
            if num_proposals == 0:
                # torch.cat would fail on an empty chunk list; emit a
                # header-only CSV so every video still gets a result file.
                pd.DataFrame(columns=col_name).to_csv(csv_path, index=False)
                continue

            chunks = {key: [] for key in pred_keys}
            for start in range(0, num_proposals, max_proposals):
                model_input = (features, video_second,
                               proposals[:, start:start + max_proposals, :],
                               None, temporal_mask)
                preds_meta = model(model_input)
                for key in pred_keys:
                    chunks[key].append(preds_meta[key].cpu())
            merged = {key: torch.cat(parts, dim=0).numpy() for key, parts in chunks.items()}

            # Model stages 1..3 are written as columns suffixed 2, 3 and the
            # unsuffixed final stage (xmin/xmax, pred_iou).
            proposals2 = merged["proposals1"] / video_duration
            proposals3 = merged["proposals2"] / video_duration
            proposals4 = merged["proposals3"] / video_duration
            all_ious2 = merged["iou1"]
            all_ious3 = merged["iou2"]
            all_ious = merged["iou3"]

            if score.ndim > 1:
                # Column 0 is the base score; columns 1 and 2 are stored as
                # xmin_score / xmax_score.
                xmin_score = score[:, 1]
                xmax_score = score[:, 2]
                score = score[:, 0]
            else:
                xmin_score = score
                xmax_score = score
            new_score = score * all_ious

            new_props = np.stack(
                [proposals4[:, 0], proposals4[:, 1], proposals3[:, 0], proposals3[:, 1], proposals2[:, 0],
                 proposals2[:, 1], new_score, score, xmin_score, xmax_score,
                 all_ious, all_ious3, all_ious2], axis=1)
            pd.DataFrame(new_props, columns=col_name).to_csv(csv_path, index=False)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值