2021-09-03


from collections import Counter

import torch


def sel_keyPoints(feature_map, num_keypoints):
    # feature_map: (B, di, N) response map; for each sample, pick the
    # num_keypoints point indices that win the per-channel arg-max most often.
    b, di, num = feature_map.size()
    bs_keypoints_list = []
    for i in range(b):
        # strongest point index for every one of the di channels
        _, idx = torch.max(feature_map[i], 1)
        # count how often each point index wins and sort by frequency
        dict_keypoints = Counter(idx.tolist())
        sort_keypoints = dict(sorted(dict_keypoints.items(), key=lambda item: item[1], reverse=True))
        sort_keypoints_idx = list(sort_keypoints.keys())
        # keypoints_list = np.array(list(set(idx[i].tolist())))
        if len(sort_keypoints_idx) > num_keypoints:
            keypoints_list = sort_keypoints_idx[:num_keypoints]
        else:
            # fewer distinct winners than requested: pad with the most frequent index
            keypoints_list = sort_keypoints_idx.copy()
            for _ in range(num_keypoints - len(keypoints_list)):
                if len(keypoints_list) == 0:
                    print('wrong')
                keypoints_list.append(sort_keypoints_idx[0])
        bs_keypoints_list.append(keypoints_list)
    return torch.LongTensor(bs_keypoints_list)
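# A minimal usage sketch (shapes are illustrative, not from the original post):
#   fmap = torch.rand(2, 8, 1024)                    # (B, di, N) scores
#   kps_idx = sel_keyPoints(fmap, num_keypoints=8)   # LongTensor of shape (2, 8)
# Each row holds the point indices that are the per-channel arg-max most often.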
    
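    # Debug visualization (presumably called from the training/eval step in the
    # original file): project the chosen point indices back onto the 640x480 frame
    # and scatter them over the RGB image as a sanity check.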
    visual = False
    if visual and cu_dt['pic_type'][0] == 1:
        def visual_pic(project_x, project_y, pic):
            import matplotlib.pyplot as plt
            plt.scatter(project_x, project_y, alpha=0.6, s=2, c='r')
            plt.imshow(pic, alpha=1)
            plt.show()
        xmap = np.array([[j for i in range(640)] for j in range(480)])
        ymap = np.array([[i for i in range(640)] for j in range(480)])
        np_key_list = end_points['keypoints_idx_lists'].cpu().detach().flatten().numpy()
        ori_choose = cu_dt['choose'].cpu().detach().flatten().numpy()
        ori_fore_point_idx = cu_dt['fore_point_idx'].cpu().detach().flatten().numpy()
        np_choose = ori_choose
        xmap_mask = xmap.flatten()[np_choose]
        ymap_mask = ymap.flatten()[np_choose]
        ori_img = data['ori_img'].cpu().detach().numpy()[0]
        visual_pic(ymap_mask, xmap_mask, ori_img)
def train(multithread=True):
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    if multithread:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
        )
    torch.manual_seed(0)

    if not args.eval_net:
        train_ds = dataset_desc.Dataset('train')
        if multithread:
            print(config.mini_batch_size)
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
            train_loader = torch.utils.data.DataLoader(
                train_ds, batch_size=config.mini_batch_size, shuffle=False,
                drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
            )
        else:
            # train_loader = torch.utils.data.DataLoader(
            #     train_ds, batch_size=32, shuffle=False,
            #     drop_last=True, pin_memory=True
            # )
            train_loader = torch.utils.data.DataLoader(
                train_ds, batch_size=1, shuffle=False,
                drop_last=True, num_workers=4
            )
        # train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        # train_loader = torch.utils.data.DataLoader(
        #     train_ds, batch_size=config.mini_batch_size, shuffle=False,
        #     drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
        # )

        val_ds = dataset_desc.Dataset('test')
        if multithread:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
            val_loader = torch.utils.data.DataLoader(
                val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
                drop_last=False, num_workers=4, sampler=val_sampler)
        else:
            # val_loader = torch.utils.data.DataLoader(
            #     val_ds, batch_size=2, shuffle=False,
            #     drop_last=False)
            val_loader = torch.utils.data.DataLoader(
                val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
                drop_last=False, num_workers=4)
        # val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
        # val_loader = torch.utils.data.DataLoader(
        #     val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
        #     drop_last=False, num_workers=4, sampler=val_sampler
        # )
    else:
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=config.test_mini_batch_size, shuffle=False,
            num_workers=20
        )

    rndla_cfg = ConfigRandLA
    # the eval and train branches constructed the identical network, so build it once
    model = FFB6D(
        n_classes=config.n_objects, n_pts=config.n_sample_points, rndla_cfg=rndla_cfg,
        n_kps=config.n_keypoints
    )
    # model = FFB6D(num_obj=config.n_objects)
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=opt_level,
    )
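    # NVIDIA apex mixed-precision wrapper: opt_level is one of 'O0'..'O3'
    # ('O0' = pure FP32, 'O1'/'O2' = mixed precision, 'O3' = pure FP16),
    # selected here via args.opt_level.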

    # default value
    it = -1  # for the initialize value of `LambdaLR` and `BNMomentumScheduler`
    best_loss = 1e10
    start_epoch = 1

    # load status from checkpoint
    if args.checkpoint is not None:
        checkpoint_status = load_checkpoint(
            model, optimizer, filename=args.checkpoint[:-8]
        )
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            assert checkpoint_status is not None, "Failed loading model."

    if not args.eval_net:
        if multithread:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], output_device=args.local_rank,
                find_unused_parameters=True
            )
        clr_div = 6
        lr_scheduler = CyclicLR(
            optimizer, base_lr=1e-5, max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular'
        )
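        # With clr_div = 6, step_size_up and step_size_down each cover 1/6 of the
        # total minibatch count, so the learning rate sweeps 1e-5 -> 1e-3 -> 1e-5
        # about three times over the full run (triangular mode, no momentum cycling).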
    else:
        lr_scheduler = None

    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay ** (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(
        model, bn_lambda=bnm_lmbd, last_epoch=it
    )
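    # Rough feel for the BN momentum schedule above, assuming (values not from this
    # post) bn_momentum=0.9, bn_decay=0.5, decay_step=2e5 and mini_batch_size=3:
    #   it = 0       -> momentum 0.9
    #   it = 100000  -> 0.9 * 0.5**1 = 0.45
    #   it = 400000  -> 0.9 * 0.5**6 ≈ 0.014
    # until the value is finally clamped at bnm_clip.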

    it = max(it, 0)  # for the initialize value of `trainer.train`

    if args.eval_net:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2), OFLoss(),
            args.test,
        )
    else:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2).to(device), OFLoss().to(device),
            args.test,
        )

    checkpoint_fd = config.log_model_dir

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(
            test_loader, is_test=True, test_pose=args.test_pose
        )
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(
            it, start_epoch, config.n_total_epoch, train_loader, None,
            val_loader, best_loss=best_loss,
            tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus,
            clr_div=clr_div
        )

        if start_epoch == config.n_total_epoch:
            _ = trainer.eval_epoch(val_loader)


if __name__ == "__main__":
    args.world_size = args.gpus * args.nodes
    train(multithread=True)
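
Since init_process_group uses the NCCL backend with init_method='env://' and the script reads args.local_rank, it is presumably meant to be started with PyTorch's distributed launcher (the file name and GPU count below are placeholders, not from the original post):

python -m torch.distributed.launch --nproc_per_node=2 train_ffb6d.py

The launcher sets MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE in the environment and passes --local_rank to each worker process, which is exactly what the env:// init method expects.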
