def sel_keyPoints(feature_map, num_keypoints):
    # For every sample in the batch, pick the `num_keypoints` candidate indices
    # that collect the most argmax votes across the feature map.
    b, di, num = feature_map.size()
    bs_keypoints_list = []
    for i in range(b):
        _, idx = torch.max(feature_map[i], 1)
        dict_keypoints = Counter(idx.tolist())
        sort_keypoints = dict(sorted(dict_keypoints.items(), key=lambda item: item[1], reverse=True))
        sort_keypoints_idx = list(sort_keypoints.keys())
        # keypoints_list = np.array(list(set(idx[i].tolist())))
        if len(sort_keypoints_idx) > num_keypoints:
            keypoints_list = sort_keypoints_idx[:num_keypoints]
        else:
            # Not enough distinct indices: pad with the most voted one.
            keypoints_list = sort_keypoints_idx.copy()
            for _ in range(num_keypoints - len(keypoints_list)):
                if len(keypoints_list) == 0:
                    print('wrong')
                keypoints_list.append(sort_keypoints_idx[0])
        bs_keypoints_list.append(keypoints_list)
    return torch.LongTensor(bs_keypoints_list)
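

# Illustrative usage sketch for sel_keyPoints (hypothetical shapes; this helper
# is not called anywhere in the training pipeline): given per-point voting
# scores of shape (batch, n_points, n_candidates), it returns the
# `num_keypoints` candidate indices that collected the most argmax votes per
# sample.
def _sel_keypoints_example():
    votes = torch.rand(2, 128, 50)    # hypothetical voting tensor
    kp_idx = sel_keyPoints(votes, 8)  # LongTensor of shape (2, 8)
    return kp_idx
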
# Optional debug visualization: project the chosen depth pixels onto the RGB
# image. It only runs when `visual` is flipped to True and relies on `cu_dt`,
# `end_points` and `data` being defined in the surrounding scope.
visual = False
if visual and cu_dt['pic_type'][0] == 1:
    def visual_pic(project_x, project_y, pic):
        import matplotlib.pyplot as plt
        plt.scatter(project_x, project_y, alpha=0.6, s=2, c='r')
        plt.imshow(pic, alpha=1)
        plt.show()

    # Pixel-coordinate lookup tables for a 480x640 image.
    xmap = np.array([[j for i in range(640)] for j in range(480)])
    ymap = np.array([[i for i in range(640)] for j in range(480)])
    np_key_list = end_points['keypoints_idx_lists'].cpu().detach().flatten().numpy()
    ori_choose = cu_dt['choose'].cpu().detach().flatten().numpy()
    ori_fore_point_idx = cu_dt['fore_point_idx'].cpu().detach().flatten().numpy()
    np_choose = ori_choose
    xmap_mask = xmap.flatten()[np_choose]
    ymap_mask = ymap.flatten()[np_choose]
    ori_img = data['ori_img'].cpu().detach().numpy()[0]
    visual_pic(ymap_mask, xmap_mask, ori_img)

def train(multithread=True):
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    if multithread:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
        )
    torch.manual_seed(0)
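
    # With init_method='env://', torch.distributed reads MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment, so the script is expected to be
    # started through a launcher such as `python -m torch.distributed.launch
    # --nproc_per_node=<gpus> <this script>` (assumed launch method), one process
    # per GPU, each with its own local_rank.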
    if not args.eval_net:
        train_ds = dataset_desc.Dataset('train')
        if multithread:
            print(config.mini_batch_size)
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
            train_loader = torch.utils.data.DataLoader(
                train_ds, batch_size=config.mini_batch_size, shuffle=False,
                drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
            )
        else:
            # train_loader = torch.utils.data.DataLoader(
            #     train_ds, batch_size=32, shuffle=False,
            #     drop_last=True, pin_memory=True
            # )
            train_loader = torch.utils.data.DataLoader(
                train_ds, batch_size=1, shuffle=False,
                drop_last=True, num_workers=4
            )
            # train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
            # train_loader = torch.utils.data.DataLoader(
            #     train_ds, batch_size=config.mini_batch_size, shuffle=False,
            #     drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
            # )

        val_ds = dataset_desc.Dataset('test')
        if multithread:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
            val_loader = torch.utils.data.DataLoader(
                val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
                drop_last=False, num_workers=4, sampler=val_sampler)
        else:
            # val_loader = torch.utils.data.DataLoader(
            #     val_ds, batch_size=2, shuffle=False,
            #     drop_last=False)
            val_loader = torch.utils.data.DataLoader(
                val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
                drop_last=False, num_workers=4)
            # val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
            # val_loader = torch.utils.data.DataLoader(
            #     val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
            #     drop_last=False, num_workers=4, sampler=val_sampler
            # )
    else:
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=config.test_mini_batch_size, shuffle=False,
            num_workers=20
        )
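
    # With DistributedSampler, each process only iterates over its own shard of
    # the dataset (1/world_size of the samples), and any shuffling is handled by
    # the sampler itself, which is why the distributed DataLoaders above are
    # built with shuffle=False.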
    rndla_cfg = ConfigRandLA
    # The same FFB6D network is built for both training and evaluation.
    model = FFB6D(
        n_classes=config.n_objects, n_pts=config.n_sample_points, rndla_cfg=rndla_cfg,
        n_kps=config.n_keypoints
    )
    # model = FFB6D(num_obj=config.n_objects)
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=opt_level,
    )
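
    # apex.amp wraps the model and optimizer for mixed-precision training;
    # opt_level "O0" keeps everything in FP32, "O3" is (almost) pure FP16, and
    # "O1"/"O2" are the usual mixed-precision settings. The value comes from
    # args.opt_level.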
    # Default values.
    it = -1  # initial value for `LambdaLR` and `BNMomentumScheduler`
    best_loss = 1e10
    start_epoch = 1

    # Load training status from a checkpoint if one is given. The [:-8] slice
    # presumably strips a ".pth.tar" suffix (8 characters) that load_checkpoint
    # appends on its own.
    if args.checkpoint is not None:
        checkpoint_status = load_checkpoint(
            model, optimizer, filename=args.checkpoint[:-8]
        )
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            assert checkpoint_status is not None, "Failed loading model."
    if not args.eval_net:
        if multithread:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], output_device=args.local_rank,
                find_unused_parameters=True
            )
        clr_div = 6
        lr_scheduler = CyclicLR(
            optimizer, base_lr=1e-5, max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular'
        )
    else:
        lr_scheduler = None

    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay ** (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(
        model, bn_lambda=bnm_lmbd, last_epoch=it
    )

    it = max(it, 0)  # initial value passed to `trainer.train`
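
    # Reading of the schedules above (derived from this code, with the usual
    # PyTorch/pt_utils semantics assumed): each CyclicLR half-cycle spans
    # 1/clr_div of the total per-GPU iterations and sweeps the learning rate
    # between 1e-5 and 1e-3, while bnm_lmbd decays the BatchNorm momentum by a
    # factor of bn_decay for every decay_step training samples seen, clipped
    # from below at bnm_clip.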
    if args.eval_net:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2), OFLoss(),
            args.test,
        )
    else:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2).to(device), OFLoss().to(device),
            args.test,
        )

    checkpoint_fd = config.log_model_dir
    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )
    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(
            test_loader, is_test=True, test_pose=args.test_pose
        )
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(
            it, start_epoch, config.n_total_epoch, train_loader, None,
            val_loader, best_loss=best_loss,
            tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus,
            clr_div=clr_div
        )

        if start_epoch == config.n_total_epoch:
            _ = trainer.eval_epoch(val_loader)


if __name__ == "__main__":
    args.world_size = args.gpus * args.nodes
    train(multithread=True)
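
# Example launch (hypothetical; script name and flag spellings are assumptions
# mirroring the args.* fields used above):
#   python -m torch.distributed.launch --nproc_per_node=<gpus> train.py --gpus <gpus> --nodes 1
# torch.distributed.launch hands a distinct local_rank to every spawned process,
# and each process then calls train(multithread=True) through the __main__ guard.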