目前还没有配置好 cpp/cu 部分(CUDA 扩展)的代码,因此没有实际调试,只是把大致的框架学习了一下,在这里做个记录,以后配置好了再进行调试。目前的理解顺下来是自洽的,但可能有错误,如果发现请大家指出。
1.demo.py的结构
1)提取目标对象的分割特征和每一像素点对应的方向。
在这里面 net 就是 Resnet18_8s,使用 Resnet18_8s 提取出 seg_pred 和 vertex_pred:其中 seg_pred 表示分割结果(每个像素是否属于目标对象),vertex_pred 表示每个像素点指向各个关键点的方向(理想情况下指向真值关键点)。
在 NetWrapper 中计算损失:loss_seg 为分割损失,smooth_l1_loss 计算方向损失(loss_vertex),compute_precision_recall 计算分割的精确率(precision)和召回率(recall)。
主要是seg_pred和vertex_pred。
2)投票生成特征点(2d特征点)
主要是这个函数:EvalWrapper 中调用了 ransac_voting_layer_v3 来生成特征点,传入的就是 seg_pred 和 vertex_pred,输出的 corner_pred 就是投票得到的特征点。
3)使用pnp计算位姿。(2d-3d特征点对应)
def demo():
    """Run the cat demo end to end on one sample.

    Steps: predict segmentation and per-pixel keypoint directions, vote for
    2D keypoints with RANSAC, recover the pose with PnP, then draw the
    predicted and ground-truth 3D bounding boxes on the image.
    """
    # Network built with the same head sizes as the checkpoint.
    backbone = Resnet18_8s(ver_dim=vote_num * 2, seg_dim=2)
    wrapped = DataParallel(NetWrapper(backbone).cuda())
    # The optimizer is only used here as an argument to load_model.
    opt = optim.Adam(wrapped.parameters(), lr=train_cfg['lr'])
    checkpoint_dir = os.path.join(cfg.MODEL_DIR, "cat_demo")
    # NOTE(review): -1 presumably selects the latest checkpoint; confirm in load_model.
    load_model(wrapped.module.net, opt, checkpoint_dir, -1)

    sample, points_3d, bb8_3d = read_data()
    batch = [tensor.unsqueeze(0).cuda() for tensor in sample]
    image, mask, vertex, vertex_weights, pose, corner_target = batch
    outputs = wrapped(image, mask, vertex, vertex_weights)
    seg_pred, vertex_pred, loss_seg, loss_vertex, precision, recall = outputs

    # Optional debugging views:
    # visualize_mask(mask)
    # visualize_vertex(vertex, vertex_weights)
    # visualize_hypothesis(image, seg_pred, vertex_pred, corner_target)
    # visualize_voting_ellipse(image, seg_pred, vertex_pred, corner_target)

    # RANSAC voting turns (segmentation, directions) into 2D keypoints.
    voter = DataParallel(EvalWrapper().cuda())
    keypoints_2d = voter(seg_pred, vertex_pred).cpu().detach().numpy()[0]

    # Camera intrinsic matrix handed to PnP for the 2D-3D correspondences.
    intrinsics = np.array([[572.4114, 0., 325.2611],
                           [0., 573.57043, 242.04899],
                           [0., 0., 1.]])
    pose_pred = pnp(points_3d, keypoints_2d, intrinsics)

    projector = Projector()
    box_pred = projector.project(bb8_3d, pose_pred, 'linemod')
    gt_pose = pose[0].detach().cpu().numpy()
    box_gt = projector.project(bb8_3d, gt_pose, 'linemod')
    rgb = imagenet_to_uint8(image.detach().cpu().numpy())[0]
    visualize_bounding_box(rgb[None, ...], box_pred[None, None, ...], box_gt[None, None, ...])
2.ransac_voting_layer_v3
1)做一些预处理。
2)cur_hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs) # [hn,vn,2]
direct 就是方向,coords 就是像素坐标,idxs 表示参与计算假设点的像素索引(形状为 [hn,vn,2],即对每个假设、每个关键点随机选取两个像素)。cur_hyp_pts 就是根据 idxs 中第 2i 和 2i+1 个索引取出对应的 coords 和 direct,也就是两个点,每个点有一个位置和一个方向(即两条直线),计算它们的交点就得到了 cur_hyp_pts。这个可以看 cu 文件,容易看懂。
3)ransac_voting.voting_for_hypothesis(direct, coords, cur_hyp_pts, cur_inlier, inlier_thresh) # [hn,vn,tn]
就是计算每个像素指向假设点的方向与该像素预测方向之间的夹角,如果夹角足够小(满足阈值 inlier_thresh;从默认值 0.999 来看应是余弦相似度阈值),则认为该像素是这个假设点的内点。这个可以去看 cu 文件,感觉 cu 文件很好理解。
4)all_win_pts = torch.matmul(b_inv(ATA), torch.unsqueeze(ATb, 2))
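这个也比较好理解。比如对于第 k 个特征点,有 m 个内点像素 $\mathbf{p}_i$ 及其对应的 m 个预测方向对该特征点进行投票。记 $\mathbf{n}_i$ 为与第 i 个像素预测方向垂直的法向量,则特征点 $\mathbf{x}$ 应落在每条直线上,即满足
$$\mathbf{n}_i^{\top}(\mathbf{p}_i - \mathbf{x}) = 0,\quad i = 1,\dots,m,$$
也就是 $\mathbf{n}_i^{\top}\mathbf{x} = \mathbf{n}_i^{\top}\mathbf{p}_i$。把所有 $\mathbf{n}_i^{\top}$ 按行堆叠成矩阵 $A$,把 $\mathbf{n}_i^{\top}\mathbf{p}_i$ 堆叠成向量 $\mathbf{b}$,得到超定方程 $A\mathbf{x} = \mathbf{b}$,其最小二乘解为
$$\mathbf{x} = (A^{\top}A)^{-1}A^{\top}\mathbf{b}.$$
代码中的 ATA 就是 $A^{\top}A$(normal.t()*normal),ATb 就是 $A^{\top}\mathbf{b}$(normal.t()*[normal*coords])。
求解出的 $\mathbf{x}$ 就是特征点位置。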
看这几行代码:
# Rotate each predicted direction (d0, d1) by 90 degrees to (d1, -d0).
# The dot product d0*d1 + d1*(-d0) is 0, so each result is perpendicular to its
# direction: it is the normal of the line through the pixel along that direction.
normal = torch.zeros_like(direct) # [tn,vn,2]
normal[:, :, 0] = direct[:, :, 1]
normal[:, :, 1] = -direct[:, :, 0]
法向(normal)表示与该像素预测方向垂直的方向:把方向 (d0, d1) 旋转 90° 得到 (d1, -d0)。
normal = normal*torch.unsqueeze(all_inlier, 2) 这一行是让 normal 乘上一个由 1 和 0 组成的指示数组(mask),表示该像素点及其方向是否给该特征点投票;也可以理解为:该像素点及其对应的方向是否是该特征点的内点。外点对应的 normal 被置零,因此不参与后面的最小二乘求解。
def ransac_voting_layer_v3(mask, vertex, round_hyp_num, inlier_thresh=0.999, confidence=0.99, max_iter=20,
                           min_num=5, max_num=30000):
    """Locate 2D keypoints by RANSAC voting over per-pixel direction predictions.

    For each image, the foreground pixels selected by ``mask`` each carry one
    predicted direction per keypoint. Every RANSAC round draws
    ``round_hyp_num`` fresh random pixel pairs per keypoint and intersects
    their two direction lines to form hypotheses. Pixels then vote for those
    hypotheses, and the best hypothesis seen so far is kept for each keypoint.
    After the loop, all inliers of the winning hypothesis are used to
    re-estimate the keypoint as the least-squares intersection of their lines.

    :param mask: [b,h,w]; a pixel is foreground when ``mask.byte()`` is non-zero.
    :param vertex: [b,h,w,vn,2] per-pixel direction for each of the vn keypoints.
    :param round_hyp_num: hypotheses generated per keypoint in each round.
    :param inlier_thresh: threshold forwarded to ``ransac_voting.voting_for_hypothesis``.
    :param confidence: stop once the estimated RANSAC success probability exceeds this.
    :param max_iter: bound on the number of rounds (the loop stops once it is exceeded).
    :param min_num: images with fewer foreground pixels get all-zero keypoints.
    :param max_num: larger foregrounds are randomly down-sampled to roughly this many pixels.
    :return: [b,vn,2] keypoint coordinates in (x, y) pixel order.
    """
    batch_size, _, _, vn, _ = vertex.shape
    device = mask.device
    batch_win_pts = []
    for bi in range(batch_size):
        # Fresh bool mask: '!= 0' allocates a new tensor, so the in-batch
        # down-sampling below can never write into the caller's mask.
        # masked_select also expects a bool mask.
        cur_mask = mask[bi].byte() != 0
        foreground_num = torch.sum(cur_mask)
        # if too few points, just skip it
        if foreground_num < min_num:
            batch_win_pts.append(torch.zeros([1, vn, 2], dtype=torch.float32, device=device))  # [1,vn,2]
            continue
        # if too many foreground pixels, randomly down-sample them
        if foreground_num > max_num:
            selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=device).uniform_(0, 1)
            cur_mask = cur_mask & (selection < (max_num / foreground_num.float()))
        coords = torch.nonzero(cur_mask).float()  # [tn,2] as (row, col)
        coords = coords[:, [1, 0]]  # -> (x, y)
        direct = vertex[bi].masked_select(torch.unsqueeze(torch.unsqueeze(cur_mask, 2), 3))
        direct = direct.view([coords.shape[0], vn, 2])  # [tn,vn,2]
        tn = coords.shape[0]

        keypoint_ids = torch.arange(vn, device=device)
        all_win_ratio = torch.zeros([vn], dtype=torch.float32, device=device)
        all_win_pts = torch.zeros([vn, 2], dtype=torch.float32, device=device)
        hyp_num = 0
        cur_iter = 0
        while True:
            # Draw new pixel pairs every round. Reusing a single draw makes every
            # later round re-score exactly the same hypotheses.
            idxs = torch.zeros([round_hyp_num, vn, 2], dtype=torch.int32, device=device).random_(0, tn)
            # generate hypothesis: intersection of the two chosen pixels' lines
            cur_hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs)  # [hn,vn,2]
            # voting for hypothesis
            cur_inlier = torch.zeros([round_hyp_num, vn, tn], dtype=torch.uint8, device=device)
            ransac_voting.voting_for_hypothesis(direct, coords, cur_hyp_pts, cur_inlier, inlier_thresh)  # [hn,vn,tn]
            # best hypothesis of this round, per keypoint
            cur_inlier_counts = torch.sum(cur_inlier, 2)  # [hn,vn]
            cur_win_counts, cur_win_idx = torch.max(cur_inlier_counts, 0)  # [vn]
            cur_win_pts = cur_hyp_pts[cur_win_idx, keypoint_ids]  # [vn,2]
            cur_win_ratio = cur_win_counts.float() / tn
            # keep it where it beats the best seen so far
            larger_mask = all_win_ratio < cur_win_ratio
            all_win_pts[larger_mask, :] = cur_win_pts[larger_mask, :]
            all_win_ratio[larger_mask] = cur_win_ratio[larger_mask]
            hyp_num += round_hyp_num
            cur_iter += 1
            # Each hypothesis uses 2 pixels, so it is all-inlier with probability
            # ratio**2. Stop when P(at least one such hypothesis) > confidence,
            # using the worst keypoint's ratio.
            cur_min_ratio = torch.min(all_win_ratio)
            if (1 - (1 - cur_min_ratio ** 2) ** hyp_num) > confidence or cur_iter > max_iter:
                break

        # Refine: least-squares intersection of every inlier line of the winner.
        normal = torch.zeros_like(direct)  # [tn,vn,2]; direction rotated by 90 degrees
        normal[:, :, 0] = direct[:, :, 1]
        normal[:, :, 1] = -direct[:, :, 0]
        all_inlier = torch.zeros([1, vn, tn], dtype=torch.uint8, device=device)
        all_win_pts = torch.unsqueeze(all_win_pts, 0)  # [1,vn,2]
        ransac_voting.voting_for_hypothesis(direct, coords, all_win_pts, all_inlier, inlier_thresh)  # [1,vn,tn]
        all_inlier = torch.squeeze(all_inlier.float(), 0)  # [vn,tn]
        normal = normal.permute(1, 0, 2)  # [vn,tn,2]
        normal = normal * torch.unsqueeze(all_inlier, 2)  # outlier rows become all zero
        rhs = torch.sum(normal * torch.unsqueeze(coords, 0), 2)  # [vn,tn]: n_i . p_i
        ATA = torch.matmul(normal.permute(0, 2, 1), normal)  # [vn,2,2]
        ATb = torch.sum(normal * torch.unsqueeze(rhs, 2), 1)  # [vn,2]
        # NOTE(review): with fewer than two non-parallel inlier lines ATA is
        # singular; the result then depends on b_inv (defined elsewhere).
        all_win_pts = torch.matmul(b_inv(ATA), torch.unsqueeze(ATb, 2))  # [vn,2,1]
        batch_win_pts.append(all_win_pts[None, :, :, 0])
    return torch.cat(batch_win_pts)
3. 与论文中公式(3)和(4)对应的代码:计算特征点投票分布的均值和协方差(均值 mean 由调用方传入,函数内部以内点比例为权重计算假设点的协方差)。
前面部分与 ransac_voting_layer_v3 基本相同(采样像素对、生成假设点并投票)。后面部分也相对好理解:只保留内点比例不低于(最大内点比例 − 0.1)的假设点,再以内点比例为权重,计算它们相对于均值的加权协方差。
def estimate_voting_distribution_with_mean(mask, vertex, mean, round_hyp_num=256, min_hyp_num=4096, topk=128, inlier_thresh=0.99, min_num=5, max_num=30000, output_hyp=False):
    """Estimate the spread of RANSAC keypoint hypotheses around a given mean.

    For each image, random pixel pairs generate keypoint hypotheses, and each
    hypothesis is scored by its inlier ratio. Only hypotheses whose ratio is
    within 0.1 of the best one for that keypoint keep their weight. The
    result is the ratio-weighted covariance of the hypotheses around ``mean``.

    :param mask: [b,h,w]; pixels equal to 1 are foreground.
    :param vertex: [b,h,w,vn,2] per-pixel direction for each of the vn keypoints.
    :param mean: keypoint means, broadcastable to [b,vn,2]; returned unchanged.
    :param round_hyp_num: hypotheses generated per keypoint in each round.
    :param min_hyp_num: minimum hypotheses per keypoint. It is rounded up to a
        multiple of ``round_hyp_num``, for every image including skipped ones.
    :param topk: currently unused; kept for interface compatibility.
    :param inlier_thresh: threshold forwarded to ``ransac_voting.voting_for_hypothesis``.
    :param min_num: images with fewer foreground pixels get all-zero
        hypotheses with weight 1.
    :param max_num: larger foregrounds are randomly down-sampled to roughly this many pixels.
    :param output_hyp: currently ignored; kept for interface compatibility.
    :return: (mean, cov) where cov is [b,vn,2,2].
    """
    batch_size, _, _, vn, _ = vertex.shape
    device = mask.device
    # Every image must yield the same number of hypotheses so the batch can be
    # concatenated. The sampling loop produces rounds * round_hyp_num of them,
    # so the skip path below allocates exactly that many too.
    rounds = int(np.ceil(min_hyp_num / round_hyp_num))
    hyp_total = rounds * round_hyp_num
    all_hyp_pts, all_inlier_ratio = [], []
    for bi in range(batch_size):
        cur_mask = mask[bi] == 1  # label 1 = object
        foreground = torch.sum(cur_mask)
        # if too few points, just skip it
        if foreground < min_num:
            all_hyp_pts.append(torch.zeros([1, hyp_total, vn, 2], dtype=torch.float32, device=device))
            all_inlier_ratio.append(torch.ones([1, hyp_total, vn], dtype=torch.float32, device=device))
            continue
        # if too many foreground pixels, randomly down-sample them
        if foreground > max_num:
            selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=device).uniform_(0, 1)
            cur_mask &= selection < (max_num / foreground.float())
            foreground = torch.sum(cur_mask)
        coords = torch.nonzero(cur_mask).float()[:, [1, 0]]  # [tn,2] as (x, y)
        direct = vertex[bi].masked_select(torch.unsqueeze(torch.unsqueeze(cur_mask, 2), 3))
        direct = direct.view([coords.shape[0], vn, 2])  # [tn,vn,2]
        tn = coords.shape[0]
        round_hyp_pts, round_inlier_ratio = [], []
        for _ in range(rounds):
            idxs = torch.zeros([round_hyp_num, vn, 2], dtype=torch.int32, device=device).random_(0, tn)
            # generate hypothesis
            hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs)  # [hn,vn,2]
            # voting for hypothesis
            inlier = torch.zeros([round_hyp_num, vn, tn], dtype=torch.uint8, device=device)
            ransac_voting.voting_for_hypothesis(direct, coords, hyp_pts, inlier, inlier_thresh)  # [hn,vn,tn]
            round_hyp_pts.append(hyp_pts)
            round_inlier_ratio.append(torch.sum(inlier, 2).float() / foreground.float())  # [hn,vn]
        all_hyp_pts.append(torch.unsqueeze(torch.cat(round_hyp_pts, 0), 0))
        all_inlier_ratio.append(torch.unsqueeze(torch.cat(round_inlier_ratio, 0), 0))
    all_hyp_pts = torch.cat(all_hyp_pts, 0).permute(0, 2, 1, 3)  # [b,vn,hn,2]
    all_inlier_ratio = torch.cat(all_inlier_ratio, 0).permute(0, 2, 1)  # [b,vn,hn]
    # Zero the weight of hypotheses more than 0.1 below the best ratio.
    thresh = torch.max(all_inlier_ratio, 2)[0] - 0.1  # [b,vn]
    all_inlier_ratio[all_inlier_ratio < torch.unsqueeze(thresh, 2)] = 0.0
    diff_pts = all_hyp_pts - torch.unsqueeze(mean, 2)  # [b,vn,hn,2]
    weighted_diff_pts = diff_pts * torch.unsqueeze(all_inlier_ratio, 3)
    cov = torch.matmul(diff_pts.transpose(2, 3), weighted_diff_pts)  # [b,vn,2,2]
    # +1e-3 avoids dividing by zero when every weight is zero (best ratio of 0).
    cov /= torch.unsqueeze(torch.unsqueeze(torch.sum(all_inlier_ratio, 2), 2), 3) + 1e-3
    return mean, cov