A survey of point cloud registration, with code notes

Contents

1.Survey


1.Introduction

2.An overview of the 3D point cloud registration methods

2.1 Feature-based registration methods

2.2. Cloud-to-cloud fine registration methods

2.2.1. ICP-derived methods

2.2.2. Probabilistic methods

2.Code

3.ref


1.Survey

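Motivation: to reduce annotation effort, use 3D registration. Only one frame has to be annotated. Registration then recovers the camera pose transformation between the two frames, and the new frame's annotation = original annotation × pose transformation matrix.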
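The last step (annotation × pose) can be sketched in numpy under some assumptions:

- The annotation is a set of 3D points, e.g. box corners, in the first frame's coordinates.
- `T` is the 4×4 homogeneous transform from the first frame to the second.
- `propagate_annotation` is a hypothetical helper name.

If registration returns the transform in the opposite direction, invert it first, e.g. with `calculate_inv` from the code section.

```python
import numpy as np


def propagate_annotation(points: np.ndarray, T: np.ndarray) -> np.ndarray:
    """Map annotated 3D points (N, 3) from frame A to frame B with a 4x4 transform T (A -> B)."""
    homo = np.hstack([points, np.ones((points.shape[0], 1))])  # (N, 4) homogeneous coordinates
    return (homo @ T.T)[:, :3]  # row-vector form of T @ p, i.e. R @ p + t


# Example: 90-degree rotation about z followed by a 1 m shift along x
T = np.eye(4)
T[:3, :3] = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
T[:3, 3] = [1.0, 0.0, 0.0]
box_corner = np.array([[1.0, 0.0, 0.0]])
moved = propagate_annotation(box_corner, T)
```

This mapping only holds for 3D annotations. 2D image labels would additionally need projection through the camera intrinsics.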

paper: https://arxiv.org/ftp/arxiv/papers/2302/2302.07184.pdf

1.Introduction

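Traditional point cloud registration methods are usually applied to homologous, cooperative datasets. Examples are registering two terrestrial laser scanning point clouds, or two airborne laser scanning point clouds. Many applications, however, need to register datasets that come from different sources, overlap only slightly, or carry inaccurate metric information. These tasks are very challenging, and traditional registration algorithms often fail on them.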

Point cloud registration consists of two main steps:

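  1. Coarse registration: estimate an initial geometric transformation between the two datasets from sparse feature-point correspondences. This step uses feature extraction and matching to find similarities between the two point clouds. Typical choices are 3D descriptors such as FPFH or, when images are available, 2D detectors such as SIFT, SURF and ORB. The goal is to roughly align the position and orientation of the two clouds and give fine registration a good initial estimate.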

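  2. Fine registration: refine the transformation with more, and denser, correspondences. This is also called cloud-to-cloud (C2C) registration. It usually relies on Iterative Closest Point (ICP) or one of its variants to find the optimal transformation. The goal is to minimize the distances (errors) between the two clouds so that they are aligned as precisely as possible.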

The survey reviews these two families: feature-based coarse registration methods and C2C fine registration methods.

2.An overview of the 3D point cloud registration methods

2.1 Feature-based registration methods

For sparse 3D correspondence extraction, features can be built from points, lines or planes.

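Point-based methods: FPFH is one example.
- It first computes the simplified point feature histogram (SPFH) of each point and of its neighbors within the local surface.
- It then combines these histograms linearly (distance-weighted) into a 33-dimensional feature vector.

A learning-based example is SpinNet, a learned point descriptor.
- It converts a local surface patch into a rotation-invariant cylindrical volume.
- It then extracts the descriptor with a 3D cylindrical convolutional network (3DCCN).

SpinNet performs well on several indoor and outdoor benchmark datasets.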
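To make the 33-dimensional FPFH concrete, here is a brute-force numpy sketch of SPFH/FPFH. It follows the pair-feature formulation of Rusu et al. and uses 11 equal-width bins for each of the three angular features.

- It is O(N²) and meant only for illustration.
- Open3D's `compute_fpfh_feature`, used in the code section, is the practical choice; its internal binning details may differ from this sketch.
- All function names here are hypothetical.

```python
import numpy as np

N_BINS = 11  # bins per angular feature -> 3 * 11 = 33 dimensions


def pair_features(p_s, n_s, p_t, n_t):
    """Darboux-frame features (alpha, phi, theta) of two oriented points."""
    d = (p_t - p_s) / np.linalg.norm(p_t - p_s)
    # take as source the point whose normal is closer to the connecting line
    if abs(n_s @ d) < abs(n_t @ d):
        n_s, n_t, d = n_t, n_s, -d
    u = n_s
    v = np.cross(u, d)
    v_norm = np.linalg.norm(v)
    if v_norm < 1e-12:  # normal parallel to the connecting line: frame undefined
        return 0.0, 0.0, 0.0
    v = v / v_norm
    w = np.cross(u, v)
    alpha = v @ n_t                       # in [-1, 1]
    phi = u @ d                           # in [-1, 1]
    theta = np.arctan2(w @ n_t, u @ n_t)  # in [-pi, pi]
    return alpha, phi, theta


def spfh(i, points, normals, nbrs):
    """Simplified PFH of point i: three 11-bin histograms over its neighbours, concatenated."""
    hist = np.zeros(3 * N_BINS)
    ranges = [(-1.0, 1.0), (-1.0, 1.0), (-np.pi, np.pi)]
    for j in nbrs:
        feats = pair_features(points[i], normals[i], points[j], normals[j])
        for k, (val, (lo, hi)) in enumerate(zip(feats, ranges)):
            b = int(np.clip((val - lo) / (hi - lo) * N_BINS, 0, N_BINS - 1))
            hist[k * N_BINS + b] += 1
    return hist / max(len(nbrs), 1)


def fpfh(points, normals, radius):
    """FPFH(p) = SPFH(p) + (1/k) * sum_j SPFH(p_j) / ||p - p_j|| over the k neighbours."""
    dists = np.linalg.norm(points[:, None] - points[None], axis=-1)
    nbrs = [np.flatnonzero((dists[i] > 0) & (dists[i] < radius)) for i in range(len(points))]
    spf = np.array([spfh(i, points, normals, nbrs[i]) for i in range(len(points))])
    out = spf.copy()
    for i, nb in enumerate(nbrs):
        if len(nb):
            out[i] += (spf[nb] / dists[i, nb][:, None]).sum(axis=0) / len(nb)
    return out


# A flat 5x5 grid with upward normals: every pair gives alpha = phi = theta = 0
xs, ys = np.meshgrid(np.arange(5) * 0.1, np.arange(5) * 0.1)
points = np.stack([xs.ravel(), ys.ravel(), np.zeros(25)], axis=1)
normals = np.tile([0.0, 0.0, 1.0], (25, 1))
feats = fpfh(points, normals, radius=0.15)
```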

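Line- and plane-based methods: one approach works in four steps.
- Detect keypoints lying on boundary lines and on intersection lines of two planes.
- Connect these keypoints to their local neighborhoods and extract geometric features such as deflection angles and point distances.
- Sort the points by this information to form conforming point sets.
- Fit these sets with an improved cubic B-spline fitting algorithm.

Other methods use extracted 3D lines and planes directly as features.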

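4-Points Congruent Sets (4PCS) methods: instead of searching for correspondences point by point, these methods match point sets that satisfy a specific topological relation. The basic unit is a set of four coplanar points, matched as one correspondence. Under such topological constraints a coplanar set is more resistant to outliers than single points.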
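4PCS relies on two ratios of a coplanar base {a, b, c, d} whose lines ab and cd intersect at e:

- r1 = |a−e|/|a−b|
- r2 = |c−e|/|c−d|

Both ratios are invariant under affine, and hence rigid, transforms. This sketch computes them with a least-squares line intersection; `fourpcs_invariants` is a hypothetical name.

```python
import numpy as np


def fourpcs_invariants(a, b, c, d):
    """Affine-invariant ratios (r1, r2) of a coplanar 4-point base.

    Solves a + r1 * (b - a) = c + r2 * (d - c) in the least-squares sense,
    i.e. finds the intersection e of lines ab and cd.
    """
    A = np.stack([b - a, -(d - c)], axis=1)  # 3x2 linear system in (r1, r2)
    (r1, r2), *_ = np.linalg.lstsq(A, c - a, rcond=None)
    e = a + r1 * (b - a)
    residual = np.linalg.norm(e - (c + r2 * (d - c)))  # ~0 when the lines really intersect
    return r1, r2, residual


a, b = np.array([0.0, 0.0, 0.0]), np.array([2.0, 0.0, 0.0])
c, d = np.array([0.5, -1.0, 0.0]), np.array([0.5, 1.0, 0.0])
r1, r2, res = fourpcs_invariants(a, b, c, d)
```

In the target cloud, 4PCS then looks for point pairs whose intermediate points, at the same (r1, r2), coincide. That gives candidate congruent bases without testing every 4-tuple.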

Several approaches exist for estimating the transformation parameters from 3D correspondences:

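Consensus maximization: a general, randomized framework. It tests many hypotheses in the solution space. Each hypothesis yields the set of inliers consistent with its transformation, and the solution is the hypothesis with the largest inlier set. The main representatives are RANSAC and its variants.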
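Below is a minimal numpy sketch of RANSAC-style consensus maximization over 3D correspondences. Each hypothesis is a rigid transform estimated from 3 random correspondences with the SVD (Kabsch) solution, and its score is the number of correspondences within `thresh`.

- Function names and default parameters are assumptions for illustration.
- This is not the Open3D routine used in the code section.

```python
import numpy as np


def kabsch(src, tgt):
    """Least-squares rigid transform (R, t) such that R @ src[i] + t ~ tgt[i]."""
    mu_s, mu_t = src.mean(axis=0), tgt.mean(axis=0)
    H = (src - mu_s).T @ (tgt - mu_t)
    U, _, Vt = np.linalg.svd(H)
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])  # guard against reflections
    R = Vt.T @ D @ U.T
    return R, mu_t - R @ mu_s


def ransac_rigid(src, tgt, thresh=0.05, iters=1000, seed=0):
    """Consensus maximization over putative correspondences src[i] <-> tgt[i]."""
    rng = np.random.default_rng(seed)
    best = np.zeros(len(src), dtype=bool)
    for _ in range(iters):
        idx = rng.choice(len(src), size=3, replace=False)  # minimal sample for a 3D rigid motion
        R, t = kabsch(src[idx], tgt[idx])
        inliers = np.linalg.norm(src @ R.T + t - tgt, axis=1) < thresh
        if inliers.sum() > best.sum():  # keep the hypothesis with the largest consensus set
            best = inliers
    if best.sum() < 3:
        raise RuntimeError("no consensus set found")
    R, t = kabsch(src[best], tgt[best])  # refit on all inliers of the best hypothesis
    return R, t, best


# Synthetic check: 70 correct and 30 corrupted correspondences
rng = np.random.default_rng(1)
src = rng.uniform(-1, 1, (100, 3))
c, s = np.cos(np.pi / 6), np.sin(np.pi / 6)
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
tgt = src @ R_true.T + np.array([0.2, -0.1, 0.3])
tgt[:30] += rng.uniform(0.5, 1.0, (30, 3))  # push 30% of targets far from their true position
R, t, inliers = ransac_rigid(src, tgt)
```

Variants mostly change the sampling (guided or prioritized) and the stopping rule. The core "sample, fit, count inliers, keep the best" loop stays the same.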

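M-estimators: a class of methods that build a robust cost function from noisy observations. Every observation (i.e., every 3D correspondence) enters an adaptive cost function according to its fitting residual. Examples include the Geman-McClure function, the truncated least squares (TLS) function, the Huber function and the ℓ1-norm.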
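M-estimators are usually solved by iteratively reweighted least squares (IRLS), where each residual r gets the weight w(r) = ρ'(r)/r. This sketch gives textbook forms of those weights for the kernels listed above, with scale `c`. The exact parameterizations in the surveyed papers may differ, and the function names are hypothetical.

```python
import numpy as np


def huber_weight(r, c):
    # rho(r) = r^2 / 2 for |r| <= c, c * (|r| - c / 2) otherwise
    a = np.abs(r)
    return np.where(a <= c, 1.0, c / np.maximum(a, 1e-12))


def geman_mcclure_weight(r, c):
    # rho(r) = c^2 r^2 / (c^2 + r^2)  ->  w(r) proportional to (c^2 / (c^2 + r^2))^2
    return (c**2 / (c**2 + r**2)) ** 2


def tls_weight(r, c):
    # truncated least squares: rho(r) = min(r^2, c^2), residuals beyond c are ignored
    return (np.abs(r) <= c).astype(float)


def l1_weight(r, eps=1e-6):
    # rho(r) = |r|  ->  w(r) = 1 / |r|; eps avoids division by zero
    return 1.0 / np.maximum(np.abs(r), eps)


r = np.array([0.0, 0.05, 0.2, 1.0])  # example residuals
print(huber_weight(r, 0.1), geman_mcclure_weight(r, 0.1), tls_weight(r, 0.1))
```

Huber still gives large residuals a slowly decaying weight. Geman-McClure and TLS effectively switch them off, which is why the latter two tolerate far more outliers.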

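Correspondence grouping: these methods group the 3D correspondences that agree with a common rigid transformation and derive the final transformation from that group. A classic example is comprehensive-evaluation-based grouping. Recent learning-based methods encode contextual information with learnable features, describing each correspondence's potential more accurately and robustly.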
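A simple grouping criterion exploits the fact that rigid motions preserve pairwise distances. Two correspondences are compatible if the distance between their source points matches the distance between their target points.

This sketch works in two steps:

- It builds the same kind of compatibility matrix as the `FCG` in `mac_fpfh` in the code section, `clamp(1 - d² / σ², 0)`.
- It scores each correspondence with the leading eigenvector (spectral matching, via power iteration).

`sigma` and the function names are assumptions.

```python
import numpy as np


def compatibility_matrix(src, tgt, sigma=0.1):
    """Pairwise length consistency between correspondences src[i] <-> tgt[i]."""
    d_src = np.linalg.norm(src[:, None] - src[None], axis=-1)
    d_tgt = np.linalg.norm(tgt[:, None] - tgt[None], axis=-1)
    M = np.clip(1.0 - (d_src - d_tgt) ** 2 / sigma**2, 0.0, None)
    np.fill_diagonal(M, 0.0)  # no self-compatibility
    return M


def spectral_scores(M, iters=100):
    """Leading eigenvector of M by power iteration; large entries = members of the consistent group."""
    v = np.ones(M.shape[0])
    for _ in range(iters):
        v = M @ v
        v /= np.linalg.norm(v) + 1e-12
    return v


rng = np.random.default_rng(0)
src = rng.uniform(-1, 1, (100, 3))
tgt = src + np.array([0.3, 0.0, -0.2])  # a pure translation keeps all pairwise lengths
tgt[:30] = rng.uniform(-1, 1, (30, 3))  # 30 random (wrong) correspondences
scores = spectral_scores(compatibility_matrix(src, tgt))
group = np.argsort(-scores)[:70]  # the most mutually consistent correspondences
```

SC2-PCR and MAC, used in the code section, build on the same matrix. They use second-order compatibility or maximal cliques instead of a single eigenvector.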

2.2. Cloud-to-cloud fine registration methods

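Cloud-to-cloud (C2C) fine registration methods iteratively update the transformation parameters by optimizing the C2C distance, which yields dense correspondences between the point clouds. The C2C distance is computed per point and takes every point into account. C2C registration is therefore usually the last step of a registration pipeline and delivers the most accurate result. C2C methods fall into two categories: ICP-derived methods and probabilistic methods.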
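One common C2C distance is the Chamfer distance. This sketch uses the symmetric mean-of-squared-nearest-neighbour variant with SciPy's `cKDTree`. Definitions differ between papers (sum vs. mean, squared vs. unsquared, one- vs. two-directional), so this is one choice among several.

```python
import numpy as np
from scipy.spatial import cKDTree


def chamfer_distance(P, Q):
    """Symmetric Chamfer distance: mean squared nearest-neighbour distance P->Q plus Q->P."""
    d_pq, _ = cKDTree(Q).query(P)  # for each point of P, distance to its nearest point in Q
    d_qp, _ = cKDTree(P).query(Q)  # and the other way round
    return np.mean(d_pq**2) + np.mean(d_qp**2)


grid = np.stack(np.meshgrid(np.arange(4.0), np.arange(4.0), np.arange(4.0)), axis=-1).reshape(-1, 3)
print(chamfer_distance(grid, grid + np.array([0.1, 0.0, 0.0])))  # about 0.02
```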

2.2.1. ICP-derived methods

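ICP-derived methods are the most common C2C fine registration methods. Standard ICP uses the Chamfer distance as its C2C distance: the sum of distances from each point in one cloud to its nearest point in the other.

ICP minimizes the sum of squared distances with an expectation-maximization (EM)-like alternation:
- establish the nearest-neighbor correspondence (Chamfer distance) for every point;
- re-estimate the registration parameters from the updated distances;
- repeat until convergence.

ICP variants improve this iteration in three ways: better error metrics, stronger correspondence search strategies, and more robust objective functions.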
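Below is a minimal point-to-point ICP in numpy and SciPy showing the alternation described above.

- Nearest-neighbour queries provide the correspondences.
- A closed-form SVD step updates the pose.
- The loop stops when the mean squared distance stops changing.

It is a textbook sketch with hypothetical names, not the Open3D `registration_icp` used in the code section.

```python
import numpy as np
from scipy.spatial import cKDTree


def best_fit_transform(src, tgt):
    """Closed-form (SVD) rigid transform for paired points."""
    mu_s, mu_t = src.mean(axis=0), tgt.mean(axis=0)
    U, _, Vt = np.linalg.svd((src - mu_s).T @ (tgt - mu_t))
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])  # avoid reflections
    R = Vt.T @ D @ U.T
    return R, mu_t - R @ mu_s


def icp_point_to_point(source, target, max_iter=100, tol=1e-10):
    """Point-to-point ICP returning a 4x4 transform that maps source onto target."""
    tree = cKDTree(target)  # built once: the target does not move
    T = np.eye(4)
    cur = source.copy()
    prev_err = np.inf
    for _ in range(max_iter):
        dist, idx = tree.query(cur)                  # correspondence step: nearest neighbours
        R, t = best_fit_transform(cur, target[idx])  # alignment step: closed-form update
        cur = cur @ R.T + t
        step = np.eye(4)
        step[:3, :3], step[:3, 3] = R, t
        T = step @ T                                 # accumulate the incremental transforms
        err = np.mean(dist**2)
        if abs(prev_err - err) < tol:
            break
        prev_err = err
    return T, err


rng = np.random.default_rng(0)
target = rng.uniform(-1, 1, (200, 3))
c, s = np.cos(np.deg2rad(3)), np.sin(np.deg2rad(3))
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = np.array([0.01, -0.01, 0.02])
source = (target - t_true) @ R_true  # so that R_true @ source[i] + t_true == target[i]
T, err = icp_point_to_point(source, target)
```

Like any local method, this only converges from a reasonable initial pose, which is why a coarse registration step comes first.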

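Improved error metrics include point-to-plane ICP, symmetric ICP and generalized ICP.
- Point-to-plane ICP adds surface normals to the error metric. It minimizes the distance from each point to the tangent plane at its corresponding point in the other cloud.
- Symmetric ICP uses a symmetric error metric that combines the surface normals of both points in a correspondence, which improves robustness.
- Generalized ICP unifies standard ICP and point-to-plane ICP in a probabilistic framework. It models the local surface around each of the two points with a covariance matrix.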
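For reference, the error metrics above can be written out. The notation is assumed here:

- correspondences (p_i, q_i) from source to target;
- normals n_{p_i}, n_{q_i};
- local covariances C_i^P, C_i^Q.

The symmetric form shown is the basic one. The original papers refine details such as how the rotation is applied.

```latex
\begin{aligned}
E_{\text{p2p}}(R,t)   &= \sum_i \lVert R p_i + t - q_i \rVert^2 \\
E_{\text{p2pl}}(R,t)  &= \sum_i \big( (R p_i + t - q_i)^\top n_{q_i} \big)^2 \\
E_{\text{sym}}(R,t)   &= \sum_i \big( (R p_i + t - q_i)^\top (R\, n_{p_i} + n_{q_i}) \big)^2 \\
E_{\text{GICP}}(R,t)  &= \sum_i d_i^\top \big( C_i^{Q} + R\, C_i^{P} R^\top \big)^{-1} d_i,
\qquad d_i = q_i - (R p_i + t)
\end{aligned}
```

With C_i^P = 0 and C_i^Q = I, the GICP objective reduces to the point-to-point objective E_p2p.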

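One improved correspondence search strategy connects points of the clouds with random straight lines to generate correspondences. The lines give cues for correct correspondences and help reject incorrect ones.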

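A more robust objective function is obtained by adding a robust kernel, such as the Welsch function, to ICP. Outliers then receive smaller weights, which improves robustness.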
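The Welsch kernel ψ(r) = 1 − exp(−r²/(2ν²)) gives the IRLS weight exp(−r²/(2ν²)). This numpy sketch applies it on fixed correspondences and shrinks ν from coarse to fine.

- Inside a full ICP the residuals would come from fresh nearest-neighbour queries at every iteration.
- The ν schedule and all function names are arbitrary choices for this illustration.

```python
import numpy as np


def weighted_kabsch(src, tgt, w):
    """Weighted least-squares rigid transform minimizing sum_i w_i ||R src_i + t - tgt_i||^2."""
    w = w / max(w.sum(), 1e-300)
    mu_s, mu_t = w @ src, w @ tgt
    H = (src - mu_s).T @ (w[:, None] * (tgt - mu_t))
    U, _, Vt = np.linalg.svd(H)
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])
    R = Vt.T @ D @ U.T
    return R, mu_t - R @ mu_s


def welsch_weight(r, nu):
    """IRLS weight of the Welsch kernel psi(r) = 1 - exp(-r^2 / (2 nu^2))."""
    return np.exp(-r**2 / (2.0 * nu**2))


def robust_align(src, tgt, nus=(1.0, 0.5, 0.25, 0.1, 0.05, 0.02, 0.01), inner=10):
    """Welsch-weighted IRLS on fixed correspondences, shrinking nu from coarse to fine."""
    R, t = weighted_kabsch(src, tgt, np.ones(len(src)))  # ordinary least squares as the start
    for nu in nus:
        for _ in range(inner):
            r = np.linalg.norm(src @ R.T + t - tgt, axis=1)  # residual of every correspondence
            R, t = weighted_kabsch(src, tgt, welsch_weight(r, nu))
    return R, t


rng = np.random.default_rng(2)
src = rng.uniform(-1, 1, (100, 3))
c, s = np.cos(np.deg2rad(10)), np.sin(np.deg2rad(10))
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = np.array([0.1, 0.0, 0.0])
tgt = src @ R_true.T + t_true
tgt[:20] += rng.uniform(0.5, 1.0, (20, 3))  # 20 gross outliers
R, t = robust_align(src, tgt)
```

Starting with a large ν keeps the problem close to least squares, which avoids bad local minima. Shrinking ν then drives the outlier weights to zero.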

Beyond these, other ICP variants improve accuracy, efficiency and robustness together, and some are adapted to specific tasks, for example KISS-ICP for odometry.

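In summary, C2C fine registration methods obtain dense correspondences between point clouds by optimizing the C2C distance. ICP-derived methods are among the most widely used of them, and many variants have been proposed to improve accuracy, robustness and efficiency.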

2.2.2. Probabilistic methods

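Probabilistic methods model point clouds as probability distributions and solve registration by minimizing the discrepancy between the two distributions. A common model is the Gaussian mixture model (GMM). A point cloud can be treated either as samples drawn from a GMM, or as the GMM itself, with each point's position serving as a component centroid. The discrepancy between the two GMM-represented clouds is minimized with a maximum likelihood estimator (MLE).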
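A concrete setup, assumed here and also used by CPD, works as follows:

- The points y_m (m = 1..M) of one cloud are the centroids of an equally weighted GMM with shared isotropic variance σ².
- The points x_n (n = 1..N) of the other cloud are samples from it.
- 𝒯 is the transformation.

The MLE objective and the EM posterior (E-step) are then:

```latex
\begin{aligned}
p(\mathbf{x}) &= \sum_{m=1}^{M} \frac{1}{M}\,\mathcal{N}\!\left(\mathbf{x} \mid \mathcal{T}(\mathbf{y}_m),\, \sigma^2 \mathbf{I}\right),
\qquad
\min_{\mathcal{T},\,\sigma^2}\; -\sum_{n=1}^{N} \log p(\mathbf{x}_n) \\
P(m \mid \mathbf{x}_n) &= \frac{\exp\!\left(-\lVert \mathbf{x}_n - \mathcal{T}(\mathbf{y}_m) \rVert^2 / 2\sigma^2\right)}
{\sum_{k=1}^{M} \exp\!\left(-\lVert \mathbf{x}_n - \mathcal{T}(\mathbf{y}_k) \rVert^2 / 2\sigma^2\right)}
\end{aligned}
```

Every point pair gets a soft weight P(m | x_n) instead of a single nearest neighbour. This is the "denser weighted correspondences" that make probabilistic methods more robust but also more expensive than ICP.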

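Most probabilistic methods model both point clouds with the same GMM, which is more sensitive than using separate GMMs. Compared with ICP-based methods, probabilistic methods build denser weighted correspondences and therefore register more robustly, but they need more memory and computation. Recent work therefore focuses on computational efficiency:
- efficient data structures (e.g., GMM-Tree);
- fast and approximate methods (e.g., the fast Gauss transform and permutohedral lattice filters);
- GPU acceleration.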

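Learning-based methods have also become widespread in recent years. Instead of modeling every point as the centroid of a Gaussian component, they group the points and model each group as one Gaussian component. Examples include deep GMM and deep hierarchical GMM methods, which show promising registration accuracy and speed compared with traditional methods.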

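A widely used probabilistic method is Coherent Point Drift (CPD), which handles both rigid and non-rigid registration. CPD treats registration as a probability density estimation problem: one point cloud gives the GMM centroids, the other is treated as (transformed) samples from that GMM, and the transformation is estimated by maximum likelihood. CPD uses fast approximations to reduce the computational cost to linear time.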
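Below is a minimal numpy sketch of rigid CPD. It follows the EM updates of Myronenko and Song's rigid CPD, with three simplifications:

- the scale is fixed to 1;
- there is no fast Gauss transform, so each iteration costs O(MN) rather than linear time;
- `w` is the weight of the uniform outlier component.

Names and defaults are assumptions; this is not a reference implementation.

```python
import numpy as np


def cpd_rigid(X, Y, w=0.0, max_iter=500, tol=1e-10):
    """Rigid CPD without scaling: returns R, t such that Y @ R.T + t ~ X.

    X (N, D) are the data points, Y (M, D) the GMM centroids; w must lie in [0, 1).
    """
    N, D = X.shape
    M = Y.shape[0]
    R, t = np.eye(D), np.zeros(D)
    sigma2 = np.sum((X[None, :, :] - Y[:, None, :]) ** 2) / (D * M * N)
    for _ in range(max_iter):
        # E-step: posterior P[m, n] that data point x_n was generated by centroid m
        TY = Y @ R.T + t
        d2 = np.sum((X[None, :, :] - TY[:, None, :]) ** 2, axis=-1)  # (M, N)
        K = np.exp(-d2 / (2.0 * sigma2))
        c = (2.0 * np.pi * sigma2) ** (D / 2.0) * w / (1.0 - w) * M / N  # uniform outlier term
        P = K / (K.sum(axis=0, keepdims=True) + c + 1e-300)
        # M-step: weighted Procrustes problem, solved in closed form
        Np = P.sum()
        mu_x = X.T @ P.sum(axis=0) / Np
        mu_y = Y.T @ P.sum(axis=1) / Np
        Xh, Yh = X - mu_x, Y - mu_y
        A = Xh.T @ P.T @ Yh
        U, _, Vt = np.linalg.svd(A)
        C = np.diag([1.0] * (D - 1) + [np.linalg.det(U @ Vt)])  # avoid reflections
        R = U @ C @ Vt
        t = mu_x - R @ mu_y
        new_sigma2 = (np.sum(P.sum(axis=0) * np.sum(Xh**2, axis=1))
                      - 2.0 * np.trace(A.T @ R)
                      + np.sum(P.sum(axis=1) * np.sum(Yh**2, axis=1))) / (Np * D)
        if abs(new_sigma2 - sigma2) < tol:
            break
        sigma2 = max(new_sigma2, 1e-12)  # keep the variance strictly positive
    return R, t


rng = np.random.default_rng(0)
Y = rng.normal(size=(100, 3)) * np.array([1.0, 0.5, 0.2])  # anisotropic source cloud
c, s = np.cos(np.deg2rad(20)), np.sin(np.deg2rad(20))
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = np.array([0.1, -0.2, 0.05])
X = Y @ R_true.T + t_true
R, t = cpd_rigid(X, Y)
```

Because σ² starts large and shrinks, early iterations behave like aligning centroids and principal axes, and later ones like a soft ICP. This is where CPD's wider convergence basin comes from.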

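In summary, probabilistic methods model point clouds as probability distributions and register them by minimizing the discrepancy between those distributions. Recent research focuses on computational efficiency (efficient data structures, fast and approximate methods) and on learning-based methods. CPD is a commonly used probabilistic method that supports both rigid and non-rigid registration.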

2.Code

Notes on the code I used earlier for point cloud registration with ICP, SC2, PPF, MAC and RANSAC.

import open3d as o3d
import numpy as np
import cv2
import time
import torch
import igraph
import math
from scipy.spatial.transform import Rotation


def calculate_inv(RT):
    R = RT[:3, :3]
    T = RT[:3, 3]
    T_inv = -np.matmul(R.transpose(), T)
    inv_RT = np.eye(4)
    inv_RT[:3, :3] = R.transpose()
    inv_RT[:3, 3] = T_inv
    return inv_RT


def rot2euler(rot):
    r = Rotation.from_matrix(rot)
    rotation = r.as_euler('XYZ')
    return rotation



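def ransac_icp(source, target):
    # Coarse registration with FPFH + RANSAC, then point-to-plane ICP refinement.
    # Both clouds must already carry normals (FPFH and point-to-plane ICP use them).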
    source_array = np.array(source.points).astype(np.float32).reshape(1, -1, 3)
    source_fpfh = o3d.pipelines.registration.compute_fpfh_feature(source,
                                                                  o3d.geometry.KDTreeSearchParamHybrid(radius=0.5,
                                                                                                       max_nn=100))
    target_fpfh = o3d.pipelines.registration.compute_fpfh_feature(target,
                                                                  o3d.geometry.KDTreeSearchParamHybrid(radius=0.5,
                                                                                                       max_nn=100))
    result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
        source, target, source_fpfh, target_fpfh, True,  # mutual_filter
        0.01,  # max_correspondence_distance
        o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
        3,  # ransac_n: correspondences sampled per hypothesis
        [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.1),  # similarity threshold in [0, 1]; 0.9 is common, 0.1 is very loose
         o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.01)
         ], o3d.pipelines.registration.RANSACConvergenceCriteria(100000,  # both parameters affect run time
                                                                 0.9999))
    trans_init_icp = result_ransac.transformation
    result_ransac_icp = o3d.pipelines.registration.registration_icp(
        source, target, 0.02, trans_init_icp,
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=5000))  # a positional 5000 would set relative_fitness instead
    final_poses = result_ransac_icp.transformation
    model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
    model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    o3d.visualization.draw_geometries([model_ts_pcd, target])
    return final_poses


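def opencv_ppf_icp(source, target):
    # PPF (point pair feature) surface matching from opencv-contrib (cv2.ppf_match_3d); both clouds need normals.
    source_down = source
    target_down = target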
    source_array = np.array(source.points).astype(np.float32).reshape(1, -1, 3)
    model_array = np.hstack((np.array(source_down.points), np.array(source_down.normals))).astype(np.float32)
    detector = cv2.ppf_match_3d_PPF3DDetector(0.025, 0.05)
    detector.trainModel(model_array)
    scene_array = np.hstack((np.array(target_down.points), np.array(target_down.normals))).astype(
        np.float32)
    result_ppfs = detector.match(scene_array, 0.025, 0.05)  # list of Pose3D results; [0] is taken as the best pose
    result_ppf = result_ppfs[0]
    print("numVotes:", result_ppf.numVotes)
    print("residual:", result_ppf.residual)
    ts_init = result_ppf.pose
    result_icp = open3d_icp(source, target, ts_init, threshold=0.02, max_iteration=5000)
    final_poses = result_icp.transformation
    model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
    model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    o3d.visualization.draw_geometries([model_ts_pcd, target])
    return final_poses


def open3d_icp(source, target, trans_init, threshold, max_iteration):
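    t1 = time.time()  # 2.5; more points give better results but run slower. The result seems to depend on the number of model points: downsample both clouds at the same ratio; a slightly larger scene bbox helps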
    reg_p2p = o3d.pipelines.registration.registration_icp(
        source, target, threshold, trans_init,
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration))  # iterations (30 by default)
    print("*****ICP time:", time.time() - t1)
    print('*****ICP transformation:', reg_p2p.transformation)
    print('*****ICP fitness:', reg_p2p.fitness)
    return reg_p2p


def extract_fpfh_features(keypts, downsample):
    keypts.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=downsample * 2, max_nn=30))
    features = o3d.pipelines.registration.compute_fpfh_feature(keypts, o3d.geometry.KDTreeSearchParamHybrid(
        radius=downsample * 5, max_nn=100))
    features = np.array(features.data).T
    features = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-6)
    return features


def rigid_transform_3d(A, B, weights=None, weight_threshold=0):
    """
    Input:
        - A:       [bs, num_corr, 3], source point cloud
        - B:       [bs, num_corr, 3], target point cloud
        - weights: [bs, num_corr]     weight for each correspondence
        - weight_threshold: float,    clips points with weight below threshold
    Output:
        - R, t
    """
    bs = A.shape[0]
    if weights is None:
        weights = torch.ones_like(A[:, :, 0])
    weights = weights.clone()  # avoid modifying the caller's tensor in place
    weights[weights < weight_threshold] = 0
    # weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6)

    # find mean of point cloud
    centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / (
            torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6)
    centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / (
            torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6)

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    # construct weight covariance matrix
    Weight = torch.diag_embed(weights)  # [bs, num_corr] -> batch of diagonal matrices
    H = Am.permute(0, 2, 1) @ Weight @ Bm  # permute(0, 2, 1) transposes each matrix in the batch

    # find rotation (Kabsch); torch.svd returns V, not V^T, i.e. H = U diag(S) V^T
    U, S, V = torch.svd(H.cpu())
    U, S, V = U.to(weights.device), S.to(weights.device), V.to(weights.device)
    delta_UV = torch.det(V @ U.permute(0, 2, 1))
    eye = torch.eye(3)[None, :, :].repeat(bs, 1, 1).to(A.device)
    eye[:, -1, -1] = delta_UV  # flip the last axis when needed to avoid a reflection
    R = V @ eye @ U.permute(0, 2, 1)
    t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1)
    # warp_A = transform(A, integrate_trans(R,t))
    # RMSE = torch.sum( (warp_A - B) ** 2, dim=-1).mean()
    return integrate_trans(R, t)


def integrate_trans(R, t):
    """
    Integrate SE3 transformations from R and t, support torch.Tensor and np.ndarray.
    Input
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation matrix
    Output
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    """
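    if len(R.shape) == 3:
        if isinstance(R, torch.Tensor):
            trans = torch.eye(4)[None].repeat(R.shape[0], 1, 1).to(R.device)
            trans[:, :3, :3] = R
            trans[:, :3, 3:4] = t.view([-1, 3, 1])
        else:
            trans = np.tile(np.eye(4), (R.shape[0], 1, 1))  # one identity per batch element
            trans[:, :3, :3] = R
            trans[:, :3, 3:4] = t.reshape([-1, 3, 1])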
    else:
        if isinstance(R, torch.Tensor):
            trans = torch.eye(4).to(R.device)
        else:
            trans = np.eye(4)
        trans[:3, :3] = R
        trans[:3, 3:4] = t
    return trans


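def mac_fpfh(src_pcd, tgt_pcd, visual=True):
    """MAC-style registration (maximal cliques on a compatibility graph) from FPFH correspondences.
    Needs a CUDA GPU (.cuda() below). More correspondences usually give better results, presumably because more inliers survive."""
    print('--------------------------1. Extract FPFH features--------------------------')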
    t0 = time.time()
    # src_kpts = src_pcd.farthest_point_down_sample(500)  # ~500 points; the model can keep more points, the scene fewer
    # tgt_kpts = tgt_pcd.farthest_point_down_sample(500)
    src_kpts = src_pcd.voxel_down_sample(0.007)  # the model can keep more points, the scene fewer
    tgt_kpts = tgt_pcd.voxel_down_sample(0.006)
    print("number of src points:", len(src_kpts.points))
    print("number of tgt points:", len(tgt_kpts.points))
    src_desc = extract_fpfh_features(src_kpts, 0.05)
    tgt_desc = extract_fpfh_features(tgt_kpts, 0.05)
    # o3d.visualization.draw_geometries([o3d.geometry.PointCloud(o3d.utility.Vector3dVector(src_desc))])
    # o3d.visualization.draw_geometries([o3d.geometry.PointCloud(o3d.utility.Vector3dVector(tgt_desc))])
    distance = np.sqrt(2 - 2 * (src_desc @ tgt_desc.T) + 1e-6)  # for unit-norm descriptors ||a - b|| = sqrt(2 - 2 a.b)
    source_idx = np.argmin(distance, axis=1)  # for each row, the index of the minimum

    print('--------------------------2. Feature matching------------------------')
    corr = np.concatenate([np.arange(source_idx.shape[0])[:, None], source_idx[:, None]],
                          axis=-1)  # many-to-one: each source point -> its nearest target descriptor
    src_pts = np.array(src_kpts.points, dtype=np.float32)[corr[:, 0]]
    tgt_pts = np.array(tgt_kpts.points, dtype=np.float32)[corr[:, 1]]
    src_pts = torch.from_numpy(src_pts).cuda()
    tgt_pts = torch.from_numpy(tgt_pts).cuda()
    t1 = time.time()
    print('1. FPFH extraction + 2. matching time:', t1 - t0)
    print('--------------------------3. Build the compatibility graph--------------------------')
    src_dist = ((src_pts[:, None, :] - src_pts[None, :, :]) ** 2).sum(-1) ** 0.5
    tgt_dist = ((tgt_pts[:, None, :] - tgt_pts[None, :, :]) ** 2).sum(-1) ** 0.5
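    cross_dis = torch.abs(src_dist - tgt_dist)  # rigid motions preserve pairwise distances
    FCG = torch.clamp(1 - cross_dis ** 2 / 0.1 ** 2, min=0)  # first-order compatibility graph
    FCG = FCG - torch.diag_embed(torch.diag(FCG))  # remove self-loops
    FCG[FCG < 0.99] = 0
    # SCG = torch.mm(FCG, FCG) * FCG  # 2-D matrix product
    SCG = torch.matmul(FCG, FCG) * FCG  # second-order compatibility graph (general matmul), ~0.65 s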
    SCG = SCG.cpu().numpy()
    graph = igraph.Graph.Adjacency((SCG > 0).tolist())
    t2 = time.time()
    print('3. Graph construction time:', t2 - t1)
    print('--------------------------4. Search maximal cliques--------------------------')
    graph.es['weight'] = SCG[SCG.nonzero()]
    graph.vs['label'] = range(0, corr.shape[0])
    graph.to_undirected()
    macs = graph.maximal_cliques(min=5)  # minimum clique size (originally 3); could be replaced by a C++ extension
    t3 = time.time()
    print('4. Clique search time:', t3 - t2)
    print('--------------------------5. Post-processing: filter cliques--------------------------')
    # ta = time.time()
    # clique_weight = mac_filter.filter(macs, SCG)
    # tb = time.time()
    # print("cython:",tb-ta)

    # macs: list of tuples of node indices; SCG: ndarray
    clique_weight = np.zeros(len(macs), dtype=float)
    for ind in range(len(macs)):
        mac = list(macs[ind])
        if len(mac) >= 3:
            # sum of SCG over all node pairs (i < j) of the clique; vectorized instead of the slow double loop
            clique_weight[ind] = np.triu(SCG[np.ix_(mac, mac)], k=1).sum()
    tc =time.time()

    clique_ind_of_node = np.ones(corr.shape[0], dtype=int) * -1
    max_clique_weight = np.zeros(corr.shape[0], dtype=float)
    max_size = 3
    for ind in range(len(macs)):
        mac = list(macs[ind])
        weight = clique_weight[ind]
        if weight > 0:
            for i in range(len(mac)):
                if weight > max_clique_weight[mac[i]]:
                    max_clique_weight[mac[i]] = weight
                    clique_ind_of_node[mac[i]] = ind
                    max_size = max(max_size, len(mac))
    td = time.time()
    print("td:", td - tc)
    filtered_clique_ind = list(set(clique_ind_of_node))
    print(filtered_clique_ind)
    if -1 in filtered_clique_ind:
        filtered_clique_ind.remove(-1)
    t4 = time.time()
    print('5. Post-processing time:', t4 - t3)
    print("--------------------------6. Group cliques by size--------------------------")
    group = []
    for s in range(3, max_size + 1):
        group.append([])
    for ind in filtered_clique_ind:
        mac = list(macs[ind])
        group[len(mac) - 3].append(ind)

    tensor_list_A = []
    tensor_list_B = []
    for i in range(len(group)):
        if len(group[i]) > 0:
            batch_A = src_pts[list(macs[group[i][0]])][None]
            batch_B = tgt_pts[list(macs[group[i][0]])][None]
            for j in range(1, len(group[i])):
                mac = list(macs[group[i][j]])
                src_corr = src_pts[mac][None]
                tgt_corr = tgt_pts[mac][None]
                batch_A = torch.cat((batch_A, src_corr), 0)
                batch_B = torch.cat((batch_B, tgt_corr), 0)
            tensor_list_A.append(batch_A)
            tensor_list_B.append(batch_B)
    t5 = time.time()
    print('6. Clique grouping time:', t5 - t4)
    print('--------------------------7. Estimate and verify transformations--------------------------')
    max_score = 0
    mac_trans_init_icp = np.eye(4)  # fallback if no hypothesis yields a positive fitness
    print("len(tensor_list_A):", len(tensor_list_A))
    for i in range(len(tensor_list_A)):
        trans = rigid_transform_3d(tensor_list_A[i], tensor_list_B[i], None, 0)  # one SVD per clique
        for j in range(len(trans)):
            trans_init_icp = trans[j].cpu().numpy().astype(np.float64)
            reg_p2ps = o3d.pipelines.registration.registration_icp(
                src_kpts, tgt_kpts, 0.005, trans_init_icp,
                # alternatives: TransformationEstimationPointToPoint(), TransformationEstimationForGeneralizedICP(), TransformationEstimationForColoredICP()
                o3d.pipelines.registration.TransformationEstimationPointToPlane(),
                o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=50))

            if reg_p2ps.fitness > max_score:  # keep the hypothesis with the best ICP fitness
                max_score = reg_p2ps.fitness
                print('max_score:', max_score)
                mac_trans_init_icp = reg_p2ps.transformation
    print("icp filter:", time.time() - t5)
    # print('--------------------------9. Final ICP and visualization--------------------------')
    reg_p2p = o3d.pipelines.registration.registration_icp(
        src_kpts, tgt_kpts, 0.005, mac_trans_init_icp,
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=5000))
    print('final icp fitness:', reg_p2p.fitness)
    final_poses = reg_p2p.transformation
    if visual:
        source_array = np.array(src_pcd.points).astype(np.float32).reshape(1, -1, 3)
        model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
        model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
        model_ts_pcd.paint_uniform_color([1, 0, 0])
        tgt_pcd.paint_uniform_color([0, 0, 1])
        o3d.visualization.draw_geometries([model_ts_pcd, tgt_pcd])
    return final_poses
# ---- Second script: SC2 matcher (Matcher_plus); imports extract_fpfh_features from the first script, presumably saved as ransac.py ----
import sys
import time
import random
import pickle
from collections import defaultdict

import numpy as np
import torch
import open3d as o3d
import trimesh.transformations as tf
from scipy.spatial import KDTree
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, fcluster

sys.path.append('.')
from ransac import extract_fpfh_features


class Matcher_plus():
    def __init__(self,
                 inlier_threshold=0.10,
                 num_node='all',
                 use_mutual=True,
                 d_thre=0.1,
                 num_iterations=10,
                 ratio=0.2,
                 nms_radius=0.1,
                 max_points=8000,
                 k1=30,
                 k2=20,
                 select_scene=None,
                 FS_TCD_thre=0.05,
                 relax_match_num=100,
                 NS_by_IC=50
                 ):
        self.inlier_threshold = inlier_threshold
        self.num_node = num_node
        self.use_mutual = use_mutual
        self.d_thre = d_thre
        self.num_iterations = num_iterations  # maximum iteration of power iteration algorithm
        self.ratio = ratio  # the maximum ratio of seeds.
        self.max_points = max_points
        self.nms_radius = nms_radius
        self.k1 = k1
        self.k2 = k2
        self.FS_TCD_thre = FS_TCD_thre
        self.relax_match_num = relax_match_num
        self.NS_by_IC = NS_by_IC

    def pick_seeds(self, dists, scores, R, max_num):
        """
        Select seeding points using Non Maximum Suppression. (here we only support bs=1)
        Input:
            - dists:       [bs, num_corr, num_corr] src keypoints distance matrix
            - scores:      [bs, num_corr]     initial confidence of each correspondence
            - R:           float              radius of nms
            - max_num:     int                maximum number of returned seeds
        Output:
            - picked_seeds: [bs, num_seeds]   the index to the seeding correspondences
        """
        assert scores.shape[0] == 1

        # parallel Non Maximum Suppression (more efficient)
        score_relation = scores.T >= scores  # [num_corr, num_corr], save the relation of leading_eig
        # score_relation[dists[0] >= R] = 1  # mask out the non-neighborhood node
        score_relation = score_relation.bool() | (dists[0] >= R).bool()
        is_local_max = score_relation.min(-1)[0].float()

        score_local_max = scores * is_local_max
        sorted_score = torch.argsort(score_local_max, dim=1, descending=True)

        # max_num = scores.shape[1]

        return_idx = sorted_score[:, 0: max_num].detach()

        return return_idx

    def cal_seed_trans(self, seeds, SC2_measure, src_keypts, tgt_keypts):
        """
        Calculate the transformation for each seeding correspondences.
        Input:
            - seeds:         [bs, num_seeds]              the index to the seeding correspondence
            - SC2_measure: [bs, num_corr, num_channels]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
        Output: leading eigenvector
            - seedwise_trans_relax:       [bs, 4, 4]    the relaxed transformation matrix selected by IC
            - final_trans:       [bs, 4, 4]             best transformation matrix selected by IC
        """
        bs, num_corr, num_channels = SC2_measure.shape[0], SC2_measure.shape[1], SC2_measure.shape[2]
        k1 = self.k1
        k2 = self.k2

        if k1 > num_channels:
            k1 = 4
            k2 = 4

        #################################
        # The first stage consensus set sampling
        # Finding the k1 nearest neighbors around each seed
        #################################
        sorted_score = torch.argsort(SC2_measure, dim=2, descending=True)
        knn_idx = sorted_score[:, :, 0: k1]
        sorted_value, _ = torch.sort(SC2_measure, dim=2, descending=True)
        idx_tmp = knn_idx.contiguous().view([bs, -1])
        idx_tmp = idx_tmp[:, :, None]
        idx_tmp = idx_tmp.expand(-1, -1, 3)

        #################################
        # construct the local SC2 measure of each consensus subset obtained in the first stage.
        #################################
        src_knn = src_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])  # [bs, num_seeds, k, 3]
        tgt_knn = tgt_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])
        src_dist = ((src_knn[:, :, :, None, :] - src_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn[:, :, :, None, :] - tgt_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_SC_measure = (cross_dist < self.d_thre).float()
        local_SC2_measure = torch.matmul(local_hard_SC_measure[:, :, :1, :], local_hard_SC_measure)

        #################################
        # perform second stage consensus set sampling
        #################################
        sorted_score = torch.argsort(local_SC2_measure, dim=3, descending=True)
        knn_idx_fine = sorted_score[:, :, :, 0: k2]

        #################################
        # construct the soft SC2 matrix of the consensus set
        #################################
        num = knn_idx_fine.shape[1]
        knn_idx_fine = knn_idx_fine.contiguous().view([bs, num, -1])[:, :, :, None]
        knn_idx_fine = knn_idx_fine.expand(-1, -1, -1, 3)
        src_knn_fine = src_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])  # [bs, num_seeds, k, 3]
        tgt_knn_fine = tgt_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])

        src_dist = ((src_knn_fine[:, :, :, None, :] - src_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn_fine[:, :, :, None, :] - tgt_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_measure = (cross_dist < self.d_thre * 2).float()
        local_SC2_measure = torch.matmul(local_hard_measure, local_hard_measure) / k2
        local_SC_measure = torch.clamp(1 - cross_dist ** 2 / self.d_thre ** 2, min=0)
        # local_SC2_measure = local_SC_measure * local_SC2_measure
        local_SC2_measure = local_SC_measure
        local_SC2_measure = local_SC2_measure.view([-1, k2, k2])

        #################################
        # Power iteratation to get the inlier probability
        #################################
        local_SC2_measure[:, torch.arange(local_SC2_measure.shape[1]), torch.arange(local_SC2_measure.shape[1])] = 0
        total_weight = self.cal_leading_eigenvector(local_SC2_measure, method='power')
        total_weight = total_weight.view([bs, -1, k2])
        total_weight = total_weight / (torch.sum(total_weight, dim=-1, keepdim=True) + 1e-6)

        #################################
        # calculate the transformation by weighted least-squares for each subsets in parallel
        #################################
        total_weight = total_weight.view([-1, k2])
        src_knn = src_knn_fine
        tgt_knn = tgt_knn_fine
        src_knn, tgt_knn = src_knn.view([-1, k2, 3]), tgt_knn.view([-1, k2, 3])

        #################################
        # compute the rigid transformation for each seed by the weighted SVD
        #################################
        seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, total_weight)
        seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])

        #################################
        # warp the source keypoints with every seed-wise hypothesis
        #################################
        pred_position = torch.einsum('bsnm,bmk->bsnk', seedwise_trans[:, :, :3, :3],
                                     src_keypts.permute(0, 2, 1)) + seedwise_trans[:, :, :3,
                                                                    3:4]  # [bs, num_seeds, 3, num_corr]
        #################################
        # calculate the inlier number for each hypothesis, and find the best transformation for each point cloud pair
        #################################
        pred_position = pred_position.permute(0, 1, 3, 2)
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        seedwise_fitness = torch.sum((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]

        relax_num = self.NS_by_IC
        if relax_num > seedwise_fitness.shape[1]:
            relax_num = seedwise_fitness.shape[1]

        batch_best_guess_relax, batch_best_guess_relax_idx = torch.topk(seedwise_fitness, relax_num)

        batch_best_guess = seedwise_fitness.argmax(dim=1)
        best_guess_ratio = seedwise_fitness[0, batch_best_guess]
        final_trans = seedwise_trans.gather(dim=1,
                                            index=batch_best_guess[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        seedwise_trans_relax = seedwise_trans.gather(dim=1,
                                                     index=batch_best_guess_relax_idx[:, :, None, None].expand(-1, -1,
                                                                                                               4, 4))
        trans_list = seedwise_trans
        return seedwise_trans_relax, final_trans, trans_list

    def cal_leading_eigenvector(self, M, method='power'):
        """
        Calculate the leading eigenvector using power iteration algorithm or torch.linalg.eigh
        Input:
            - M:      [bs, num_corr, num_corr] the compatibility matrix
            - method: select different method for calculating the leading eigenvector.
        Output:
            - solution: [bs, num_corr] leading eigenvector
        """
        if method == 'power':
            # power iteration algorithm
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for i in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                if torch.allclose(leading_eig, leading_eig_last):
                    break
                leading_eig_last = leading_eig
            leading_eig = leading_eig.squeeze(-1)
            return leading_eig
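        elif method == 'eig':  # may cause NaN during back-prop
            # torch.symeig is deprecated (removed in recent PyTorch); eigh returns eigenvalues in ascending order
            e, v = torch.linalg.eigh(M)
            leading_eig = v[:, :, -1]
            return leading_eig
        else:
            raise ValueError(f"unknown method: {method}")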

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        """
        Calculate the confidence of the spectral matching solution based on spectral analysis.
        Input:
            - M:          [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig [bs, num_corr]           the leading eigenvector of matrix M
        Output:
            - confidence
        """
        if method == 'eig_value':
            # max eigenvalue as the confidence (Rayleigh quotient)
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            confidence = max_eig_value.squeeze(-1)
            return confidence
        elif method == 'eig_value_ratio':
            # max eigenvalue / second max eigenvalue as the confidence
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            # compute the second largest eigen-value
            B = M - max_eig_value * leading_eig[:, :, None] @ leading_eig[:, None, :]
            solution = torch.ones_like(B[:, :, 0:1])
            for i in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            solution = solution.squeeze(-1)
            second_eig = solution
            second_eig_value = (second_eig[:, None, :] @ B @ second_eig[:, :, None]) / (
                    second_eig[:, None, :] @ second_eig[:, :, None])
            confidence = max_eig_value / second_eig_value
            return confidence
        elif method == 'xMx':
            # max xMx as the confidence (x is the binary solution)
            # rank = torch.argsort(leading_eig, dim=1, descending=True)[:, 0:int(M.shape[1]*self.ratio)]
            # binary_sol = torch.zeros_like(leading_eig)
            # binary_sol[0, rank[0]] = 1
            confidence = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
            confidence = confidence.squeeze(-1) / M.shape[1]
            return confidence

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, it_num, weights=None):
        """
        Perform post refinement using the initial transformation matrix, only adopted during testing.
        Input
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - it_num:        int, maximum number of refinement iterations
            - weights:       [bs, num_corr] (currently unused)
        Output:
            - final_trans:   [bs, 4, 4]
        """
        assert initial_trans.shape[0] == 1
        inlier_threshold = 1.2

        # inlier_threshold_list = [self.inlier_threshold] * it_num

        if self.inlier_threshold == 0.10:  # for 3DMatch
            inlier_threshold_list = [0.10] * it_num
        else:  # for KITTI
            inlier_threshold_list = [1.2] * it_num

        previous_inlier_num = 0
        for inlier_threshold in inlier_threshold_list:
            warped_src_keypts = transform(src_keypts, initial_trans)

            L2_dis = torch.norm(warped_src_keypts - tgt_keypts, dim=-1)
            pred_inlier = (L2_dis < inlier_threshold)[0]  # assume bs = 1
            inlier_num = torch.sum(pred_inlier)
            if abs(int(inlier_num - previous_inlier_num)) < 1:
                break
            else:
                previous_inlier_num = inlier_num
            initial_trans = rigid_transform_3d(
                A=src_keypts[:, pred_inlier, :],
                B=tgt_keypts[:, pred_inlier, :],
                ## https://link.springer.com/article/10.1007/s10589-014-9643-2
                # weights=None,
                weights=1 / (1 + (L2_dis / inlier_threshold) ** 2)[:, pred_inlier],
                # weights=((1-L2_dis/inlier_threshold)**2)[:, pred_inlier]
            )
        return initial_trans

    def match_pair(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Match source and target keypoints in feature space: one nearest neighbour per source point plus K relaxed candidates
        Input:
            - src_keypts:  [bs, N, 3]   source point cloud
            - tgt_keypts   [bs, M, 3]   target point cloud
            - src_features  [bs, N,C]  the features of source point cloud
            - tgt_features [bs, M, C]  the features of target point cloud
        Output:
            - src_keypts:  [bs, N, 3]   source point cloud
            - relax_match_points  [1, N, K, 3]  for each source point, we find K target points as the potential correspondences
            - relax_distance [bs, N, K]  feature distance for the relaxed matches
            - src_keypts_corr [bs, N_C, 3]  source points of N_C one-to-one correspondences
            - tgt_keypts_corr [bs, N_C, 3]  target points of N_C one-to-one correspondences
        """

        N_src = src_features.shape[1]
        N_tgt = tgt_features.shape[1]
        # use all points, or randomly sample up to self.num_node points from each cloud
        if self.num_node == 'all':
            src_sel_ind = np.arange(N_src)
            tgt_sel_ind = np.arange(N_tgt)
        else:
            # src_sel_ind = np.random.choice(N_src, self.num_node)
            if self.num_node < N_tgt:
                tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
            else:
                tgt_sel_ind = np.arange(N_tgt)

            if self.num_node < N_src:
                src_sel_ind = np.random.choice(N_src, self.num_node)
            else:
                src_sel_ind = np.arange(N_src)
            # tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
        src_desc = src_features[:, src_sel_ind, :]
        tgt_desc = tgt_features[:, tgt_sel_ind, :]
        src_keypts = src_keypts[:, src_sel_ind, :]
        tgt_keypts = tgt_keypts[:, tgt_sel_ind, :]

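        # match points in feature space; assumes L2-normalized descriptors, so ||a - b||^2 = 2 - 2 * a.b
        distance = torch.sqrt(2 - 2 * (src_desc[0] @ tgt_desc[0].T) + 1e-6)
        distance = distance.unsqueeze(0)
        source_idx = torch.argmin(distance[0], dim=1)  # index of the most similar target point for each source point
        corr = torch.cat([torch.arange(source_idx.shape[0], device=source_idx.device)[:, None], source_idx[:, None]],
                         dim=-1)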

        # relax_num = distance.shape[1] // 50
        # if relax_num < 100:
        # relax_num = distance.shape[2] // 100
        relax_num = self.relax_match_num
        relax_distance, relax_source_idx = torch.topk(distance, k=relax_num, dim=-1, largest=False)

        relax_source_idx = relax_source_idx.view(relax_source_idx.shape[0], -1)[:, :, None].expand(-1, -1, 3)
        relax_match_points = tgt_keypts.gather(dim=1, index=relax_source_idx).view(relax_source_idx.shape[0], -1,
                                                                                   relax_num, 3)
        # generate correspondences
        src_keypts_corr = src_keypts[:, corr[:, 0]]
        tgt_keypts_corr = tgt_keypts[:, corr[:, 1]]

        return src_keypts, relax_match_points, relax_distance, src_keypts_corr, tgt_keypts_corr

    def select_best_trans(self, seed_trans, src_keypts, relax_match_points, relax_distance, src_keypts_corr,
                          tgt_keypts_corr):

        """
        Select the best model from the rough models filtered by IC Metric
        Input:
            - seed_trans:  [bs, N_s^{'}, 4, 4]   the models selected by IC, N_s^{'} is the number of reserved transformations
            - src_keypts   [bs, N, 3]   the source point cloud
            - relax_match_points  [1, N, K, 3]  for each source point, we find K target points as the potential correspondences
            - relax_distance [bs, N, K]  feature distance for the relaxed matches
            - src_keypts_corr [bs, N_C, 3]  source points of N_C one-to-one correspondences
            - tgt_keypts_corr [bs, N_C, 3]  target points of N_C one-to-one correspondences
        Output:
            - the best transformation selected by FS-TCD
        """

        seed_num = seed_trans.shape[1]
        # self.inlier_threshold == 0.10: # for 3DMatch

        best_trans = None
        best_fitness = 0

        for i in range(seed_num):
            # 1. refine the transformation by the one-to-one correspondences
            initial_trans = seed_trans[:, i, :, :]
            initial_trans = self.post_refinement(initial_trans, src_keypts_corr, tgt_keypts_corr, 1)

            # 2. use the transformation to project the source point cloud to target point cloud, and find the nearest neighbor
            warped_src_keypts = transform(src_keypts, initial_trans)
            cross_dist = torch.norm((warped_src_keypts[:, :, None, :] - relax_match_points), dim=-1)
            warped_neighbors = (cross_dist <= self.inlier_threshold).float()
            renew_distance = relax_distance + 2 * (cross_dist > self.inlier_threshold * 1.5).float()
            _, mask_min_idx = renew_distance.min(dim=-1)

            # 3. find the correspondences whose alignment error is less than the threshold
            corr = torch.cat([torch.arange(mask_min_idx.shape[1], device=mask_min_idx.device)[:, None],
                              mask_min_idx[0][:, None]], dim=-1)
            verify_mask = warped_neighbors
            verify_mask_row = verify_mask.sum(-1) > 0

            # 4. use the spatial consistency to verify the correspondences
            if verify_mask_row.float().sum() > 0:
                verify_mask_row_idx = torch.where(verify_mask_row == True)
                corr_select = corr[verify_mask_row_idx[1]]
                select_relax_match_points = relax_match_points[:, verify_mask_row_idx[1]]
                src_keypts_corr = src_keypts[:, corr_select[:, 0]]
                tgt_keypts_corr = select_relax_match_points.gather(dim=2,
                                                                   index=corr_select[:, 1][None, :, None, None].expand(
                                                                       -1, -1, -1, 3)).squeeze(dim=2)
                src_dist = torch.norm((src_keypts_corr[:, :, None, :] - src_keypts_corr[:, None, :, :]), dim=-1)
                target_dist = torch.norm((tgt_keypts_corr[:, :, None, :] - tgt_keypts_corr[:, None, :, :]), dim=-1)
                corr_compatibility = src_dist - target_dist
                abs_corr_compatibility = torch.abs(corr_compatibility)

                SC_thre = self.FS_TCD_thre
                corr_compatibility_2 = (abs_corr_compatibility < SC_thre).float()
                compatibility_num = torch.sum(corr_compatibility_2, -1)
                renew_fitness = torch.max(compatibility_num)
            else:
                renew_fitness = 0

            if renew_fitness > best_fitness:
                best_trans = initial_trans
                best_fitness = renew_fitness

        return best_trans

    def SC2_PCR(self, src_keypts, tgt_keypts):
        """
        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
        Output:
            - potential_trans_by_IC:  [bs, N_s^{'}, 4, 4], the potential transformation matrices selected by IC metric.
            - best_trans_by_IC:       [bs, 4, 4], the best transformation matrix selected by IC metric.
            - trans_list:             the seed-wise transformations returned by cal_seed_trans.
        """
        bs, num_corr = src_keypts.shape[0], tgt_keypts.shape[1]

        #################################
        # downsample points
        #################################
        if num_corr > self.max_points:
            src_keypts = src_keypts[:, :self.max_points, :]
            tgt_keypts = tgt_keypts[:, :self.max_points, :]
            num_corr = self.max_points

        #################################
        # compute cross dist
        #################################
        src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
        target_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
        cross_dist = torch.abs(src_dist - target_dist)

        #################################
        # compute first order measure
        #################################
        SC_dist_thre = self.d_thre
        SC_measure = torch.clamp(1.0 - cross_dist ** 2 / SC_dist_thre ** 2, min=0)
        hard_SC_measure = (cross_dist < SC_dist_thre).float()

        #################################
        # select reliable seed correspondences
        #################################
        confidence = self.cal_leading_eigenvector(SC_measure, method='power')
        seeds = self.pick_seeds(src_dist, confidence, R=self.nms_radius, max_num=int(num_corr * self.ratio))

        #################################
        # compute second order measure
        #################################
        SC2_dist_thre = self.d_thre / 2
        hard_SC_measure_tight = (cross_dist < SC2_dist_thre).float()
        seed_hard_SC_measure = hard_SC_measure.gather(dim=1,
                                                      index=seeds[:, :, None].expand(-1, -1, num_corr))
        seed_hard_SC_measure_tight = hard_SC_measure_tight.gather(dim=1,
                                                                  index=seeds[:, :, None].expand(-1, -1, num_corr))
        SC2_measure = torch.matmul(seed_hard_SC_measure_tight, hard_SC_measure_tight) * seed_hard_SC_measure

        #################################
        # compute the seed-wise transformations and select the best one
        #################################
        potential_trans_by_IC, best_trans_by_IC, trans_list = self.cal_seed_trans(seeds, SC2_measure, src_keypts,
                                                                                  tgt_keypts)

        return potential_trans_by_IC, best_trans_by_IC, trans_list

    def estimator(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
            - src_features: [bs, num_corr, C]
            - tgt_features: [bs, num_corr, C]
        Output:
            - src_keypts_corr:  [bs, num_corr, 3], the source points in the matched correspondences
            - tgt_keypts_corr:  [bs, num_corr, 3], the target points in the matched correspondences
            - trans_list:       the seed-wise transformations returned by SC2_PCR
            (pred_trans [bs, 4, 4] and pred_labels [bs, num_corr] are only produced by the commented-out code below)
        """
        #################################
        # generate coarse correspondences
        #################################
        src_keypts, relax_match_points, relax_distance, src_keypts_corr, tgt_keypts_corr = self.match_pair(src_keypts,
                                                                                                           tgt_keypts,
                                                                                                           src_features,
                                                                                                           tgt_features)

        #################################
        # use the proposed SC2-PCR to estimate the rigid transformation
        #################################
        seedwise_trans, _, trans_list = self.SC2_PCR(src_keypts_corr, tgt_keypts_corr)

        # select_trans = self.select_best_trans(seedwise_trans, src_keypts, relax_match_points,
        #                                       relax_distance, src_keypts_corr, tgt_keypts_corr)
        #
        # pred_trans = self.post_refinement(select_trans, src_keypts_corr, tgt_keypts_corr, 20)
        #
        # frag1_warp = transform(src_keypts_corr, pred_trans)
        # distance = torch.sum((frag1_warp - tgt_keypts_corr) ** 2, dim=-1) ** 0.5
        # pred_labels = (distance < self.inlier_threshold).float()

        return src_keypts_corr, tgt_keypts_corr, trans_list
        # return pred_trans, pred_labels, src_keypts_corr, tgt_keypts_corr,trans_list
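
SC2_PCR above ranks correspondences by spatial compatibility: a rigid motion preserves lengths, so two inlier correspondences keep the same pairwise distance in the source and the target cloud. The sketch below reproduces, for a single unbatched pair in plain NumPy, the first-order soft measure, the power-iteration confidence used to pick seeds, and the second-order count (tight threshold d_thre / 2, masked by the first-order hard measure). The helper names `sc2_measures` and `leading_eigenvector` and the toy scene are illustrative assumptions, not part of the original code.

```python
import numpy as np


def sc2_measures(src, tgt, d_thre):
    """Spatial-compatibility measures for N putative correspondences.

    src, tgt: [N, 3] arrays; row i of src is matched to row i of tgt.
    Returns the soft first-order measure and the second-order SC2 measure.
    """
    # pairwise distances inside each cloud
    src_dist = np.linalg.norm(src[:, None, :] - src[None, :, :], axis=-1)
    tgt_dist = np.linalg.norm(tgt[:, None, :] - tgt[None, :, :], axis=-1)
    # a rigid motion preserves lengths, so inlier pairs give |d_src - d_tgt| close to 0
    cross_dist = np.abs(src_dist - tgt_dist)
    sc_soft = np.clip(1.0 - cross_dist ** 2 / d_thre ** 2, 0.0, None)
    hard = (cross_dist < d_thre).astype(np.float64)
    hard_tight = (cross_dist < d_thre / 2).astype(np.float64)
    # second order: number of correspondences tightly compatible with both i and j,
    # kept only where i and j are themselves compatible
    sc2 = (hard_tight @ hard_tight) * hard
    return sc_soft, sc2


def leading_eigenvector(M, num_iterations=10):
    """Power iteration started from the all-ones vector, as in cal_leading_eigenvector."""
    v = np.ones(M.shape[0])
    for _ in range(num_iterations):
        v = M @ v
        v = v / (np.linalg.norm(v) + 1e-6)
    return v


# toy scene: 7 inliers related by a known rigid motion, 3 gross outliers
rng = np.random.default_rng(0)
src = rng.uniform(0.0, 1.0, size=(10, 3))
angle = np.deg2rad(30.0)
R = np.array([[np.cos(angle), -np.sin(angle), 0.0],
              [np.sin(angle), np.cos(angle), 0.0],
              [0.0, 0.0, 1.0]])
t = np.array([0.5, -0.2, 0.1])
tgt = src @ R.T + t
tgt[7:] = rng.uniform(5.0, 6.0, size=(3, 3))  # corrupt the last three matches

sc_soft, sc2 = sc2_measures(src, tgt, d_thre=0.1)
confidence = leading_eigenvector(sc_soft)
print(np.round(confidence, 3))
print(sc2.sum(axis=1))
```

Every inlier pair ends up with an SC2 value equal to the number of inliers, while pairs involving an outlier collapse to zero, which is why SC2_PCR grows consensus sets around high-SC2 seeds instead of relying on the first-order measure alone.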


def integrate_trans(R, t):
    """
    Build SE3 transformation matrices from R and t; supports torch.Tensor and np.ndarray.
    Input
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation vector
    Output
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    """
    if len(R.shape) == 3:
        if isinstance(R, torch.Tensor):
            trans = torch.eye(4)[None].repeat(R.shape[0], 1, 1).to(R.device)
        else:
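            # one identity per batch element (np.eye(4)[None] alone cannot hold bs > 1 poses)
            trans = np.tile(np.eye(4)[None], (R.shape[0], 1, 1))
        trans[:, :3, :3] = R
        # reshape works for both torch.Tensor and np.ndarray (ndarray.view does not take a shape)
        trans[:, :3, 3:4] = t.reshape(-1, 3, 1)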
    else:
        if isinstance(R, torch.Tensor):
            trans = torch.eye(4).to(R.device)
        else:
            trans = np.eye(4)
        trans[:3, :3] = R
        trans[:3, 3:4] = t
    return trans
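
Because integrate_trans packs R and t into a homogeneous 4×4 matrix, composing and inverting poses reduces to matrix algebra. A minimal NumPy sketch for a single (unbatched) transform; `make_se3` and `invert_se3` are hypothetical helper names. The inverse uses the closed form [Rᵀ, −Rᵀt] rather than a general matrix inverse, which is cheaper and numerically cleaner for rigid transforms.

```python
import numpy as np


def make_se3(R, t):
    """Pack a 3x3 rotation and a 3-vector translation into a 4x4 homogeneous matrix."""
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = np.asarray(t, dtype=float).reshape(3)
    return T


def invert_se3(T):
    """Closed-form inverse of a rigid transform: [R, t]^-1 = [R^T, -R^T t]."""
    R, t = T[:3, :3], T[:3, 3]
    return make_se3(R.T, -R.T @ t)


angle = np.deg2rad(45.0)
R = np.array([[1.0, 0.0, 0.0],
              [0.0, np.cos(angle), -np.sin(angle)],
              [0.0, np.sin(angle), np.cos(angle)]])
T = make_se3(R, [1.0, 2.0, 3.0])
T_inv = invert_se3(T)

# a point in homogeneous coordinates goes there and back again
p = np.array([0.3, -0.4, 2.0, 1.0])
p_back = T_inv @ (T @ p)
```

Since rigid_transform_3d returns a T with B ≈ T·A (source to target), invert_se3(T) maps target-frame points back into the source frame.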


def transform(pts, trans):
    """
    Applies the SE3 transformations; supports torch.Tensor and np.ndarray (the batched path uses .permute, so it needs torch.Tensor).  Equation: trans_pts = R @ pts + t
    Input
        - pts: [num_pts, 3] or [bs, num_pts, 3], pts to be transformed
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    Output
        - pts: [num_pts, 3] or [bs, num_pts, 3] transformed pts
    """
    if len(pts.shape) == 3:
        trans_pts = trans[:, :3, :3] @ pts.permute(0, 2, 1) + trans[:, :3, 3:4]
        return trans_pts.permute(0, 2, 1)
    else:
        trans_pts = trans[:3, :3] @ pts.T + trans[:3, 3:4]
        return trans_pts.T
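
The goal stated at the start of this post is to label one frame and carry the labels to the next frame with the estimated transform. Assuming the labels are 3D points such as the eight corners of a box, and that the estimated 4×4 matrix maps source-frame coordinates to target-frame coordinates (the same convention as transform above), the transfer is one matrix product. `box_corners` and `transfer_labels` are hypothetical helpers, and the box and pose values are made up for illustration.

```python
import itertools

import numpy as np


def box_corners(center, size):
    """Eight corners of an axis-aligned box (hypothetical label format)."""
    center = np.asarray(center, dtype=float)
    half = np.asarray(size, dtype=float) / 2.0
    signs = np.array(list(itertools.product([-1.0, 1.0], repeat=3)))  # [8, 3]
    return center + signs * half


def transfer_labels(points, trans):
    """Map [N, 3] label points with a 4x4 SE3 matrix: p' = R p + t."""
    return points @ trans[:3, :3].T + trans[:3, 3]


# an example pose: 10 degrees of yaw plus a translation
angle = np.deg2rad(10.0)
trans = np.eye(4)
trans[:3, :3] = [[np.cos(angle), -np.sin(angle), 0.0],
                 [np.sin(angle), np.cos(angle), 0.0],
                 [0.0, 0.0, 1.0]]
trans[:3, 3] = [2.0, 0.5, 0.0]

corners_src = box_corners(center=[1.0, 1.0, 0.5], size=[4.0, 2.0, 1.5])
corners_tgt = transfer_labels(corners_src, trans)
```

The transferred box is generally no longer axis-aligned, so it is safer to keep it as corners (or as center plus yaw) than to re-fit an axis-aligned box.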


def rigid_transform_3d(A, B, weights=None, weight_threshold=0):
    """
    Input:
        - A:       [bs, num_corr, 3], source point cloud
        - B:       [bs, num_corr, 3], target point cloud
        - weights: [bs, num_corr]     weight for each correspondence
        - weight_threshold: float,    clips points with weight below threshold
    Output:
        - trans: [bs, 4, 4], SE3 transformation that maps A onto B (B ≈ R @ A + t)
    """
    bs = A.shape[0]
    if weights is None:
        weights = torch.ones_like(A[:, :, 0])
    weights[weights < weight_threshold] = 0
    # weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6)

    # find mean of point cloud
    centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / (
            torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6)
    centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / (
            torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6)

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    # construct weight covariance matrix
    Weight = torch.diag_embed(weights)
    H = Am.permute(0, 2, 1) @ Weight @ Bm

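    # find rotation; torch.svd returns (U, S, V) with H = U diag(S) V^T, i.e. V itself, not V^T
    U, S, V = torch.svd(H.cpu())
    U, S, V = U.to(weights.device), S.to(weights.device), V.to(weights.device)
    # flip the last axis when needed so that det(R) = +1 (reject reflections)
    delta_UV = torch.det(V @ U.permute(0, 2, 1))
    eye = torch.eye(3)[None, :, :].repeat(bs, 1, 1).to(A.device)
    eye[:, -1, -1] = delta_UV
    R = V @ eye @ U.permute(0, 2, 1)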
    t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1)
    # warp_A = transform(A, integrate_trans(R,t))
    # RMSE = torch.sum( (warp_A - B) ** 2, dim=-1).mean()
    return integrate_trans(R, t)
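
rigid_transform_3d solves the weighted least-squares problem min Σ wᵢ‖R aᵢ + t − bᵢ‖² in closed form (weighted Kabsch / SVD). Below is a single-pair NumPy sketch of the same steps, handy for checking the PyTorch version; `weighted_kabsch` is a hypothetical name and the toy data is synthetic. Note that np.linalg.svd returns Vᵀ while torch.svd returns V, so the transpose has to be undone before forming R.

```python
import numpy as np


def weighted_kabsch(A, B, w):
    """Closed-form R, t minimizing sum_i w_i * ||R @ a_i + t - b_i||^2.

    A, B: [N, 3] corresponding points; w: [N] non-negative weights with a positive sum.
    """
    w = np.asarray(w, dtype=float)
    # weighted centroids
    ca = (w[:, None] * A).sum(axis=0) / w.sum()
    cb = (w[:, None] * B).sum(axis=0) / w.sum()
    Am, Bm = A - ca, B - cb
    # weighted cross-covariance H = Am^T diag(w) Bm (same H as rigid_transform_3d)
    H = Am.T @ (w[:, None] * Bm)
    # NumPy returns V^T here, whereas torch.svd returns V
    U, S, Vt = np.linalg.svd(H)
    V = Vt.T
    # flip the last singular direction if needed so that det(R) = +1 (no reflection)
    d = np.sign(np.linalg.det(V @ U.T))
    R = V @ np.diag([1.0, 1.0, d]) @ U.T
    t = cb - R @ ca
    return R, t


# toy data: 15 exact correspondences plus 5 outliers that get zero weight
rng = np.random.default_rng(1)
A = rng.normal(size=(20, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(Q) < 0:
    Q[:, 0] = -Q[:, 0]  # make it a proper rotation
R_true, t_true = Q, np.array([0.2, -1.0, 0.5])
B = A @ R_true.T + t_true
B[15:] += rng.normal(scale=3.0, size=(5, 3))
w = np.ones(20)
w[15:] = 0.0

R_est, t_est = weighted_kabsch(A, B, w)
```

A zero weight is equivalent to discarding a correspondence; post_refinement above instead restricts the fit to the current inliers and passes the soft weights 1 / (1 + (r / τ)²) computed from the residuals r.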


class Matcher():
    def __init__(self,
                 inlier_threshold=0.10,
                 num_node='all',
                 use_mutual=True,
                 d_thre=0.1,
                 num_iterations=10,
                 ratio=0.2,
                 nms_radius=0.1,
                 max_points=8000,
                 k1=30,
                 k2=20,
                 select_scene=None,
                 ):
        self.inlier_threshold = inlier_threshold
        self.num_node = num_node
        self.use_mutual = use_mutual
        self.d_thre = d_thre
        self.num_iterations = num_iterations  # maximum iteration of power iteration algorithm
        self.ratio = ratio  # the maximum ratio of seeds.
        self.max_points = max_points
        self.nms_radius = nms_radius
        self.k1 = k1
        self.k2 = k2

    def pick_seeds(self, dists, scores, R, max_num):
        """
        Select seeding points using Non Maximum Suppression. (here we only support bs=1)
        Input:
            - dists:       [bs, num_corr, num_corr] src keypoints distance matrix
            - scores:      [bs, num_corr]     initial confidence of each correspondence
            - R:           float              radius of nms
            - max_num:     int                maximum number of returned seeds
        Output:
            - picked_seeds: [bs, num_seeds]   the index to the seeding correspondences
        """
        assert scores.shape[0] == 1

        # parallel Non Maximum Suppression (more efficient)
        score_relation = scores.T >= scores  # [num_corr, num_corr], save the relation of leading_eig
        # score_relation[dists[0] >= R] = 1  # mask out the non-neighborhood node
        score_relation = score_relation.bool() | (dists[0] >= R).bool()
        is_local_max = score_relation.min(-1)[0].float()

        score_local_max = scores * is_local_max
        sorted_score = torch.argsort(score_local_max, dim=1, descending=True)

        # max_num = scores.shape[1]

        return_idx = sorted_score[:, 0: max_num].detach()

        return return_idx

    def cal_seed_trans(self, seeds, SC2_measure, src_keypts, tgt_keypts):
        """
        Calculate the transformation for each seeding correspondences.
        Input:
            - seeds:         [bs, num_seeds]              the index to the seeding correspondence
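            - SC2_measure:   [bs, num_seeds, num_corr]    second-order compatibility between each seed and every correspondence
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
        Output:
            - final_trans_list:  [bs, num_seeds, 4, 4]   seed-wise transformation matrices
            - final_trans:       [bs, 4, 4]              the seed-wise transformation with the most inliers, for each batch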
        """
        bs, num_corr, num_channels = SC2_measure.shape[0], SC2_measure.shape[1], SC2_measure.shape[2]
        k1 = self.k1
        k2 = self.k2

        if k1 > num_channels:
            k1 = 4
            k2 = 4

        #################################
        # The first stage consensus set sampling
        # Finding the k1 nearest neighbors around each seed
        #################################
        sorted_score = torch.argsort(SC2_measure, dim=2, descending=True)
        knn_idx = sorted_score[:, :, 0: k1]
        sorted_value, _ = torch.sort(SC2_measure, dim=2, descending=True)
        idx_tmp = knn_idx.contiguous().view([bs, -1])
        idx_tmp = idx_tmp[:, :, None]
        idx_tmp = idx_tmp.expand(-1, -1, 3)

        #################################
        # construct the local SC2 measure of each consensus subset obtained in the first stage.
        #################################
        src_knn = src_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])  # [bs, num_seeds, k, 3]
        tgt_knn = tgt_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])
        src_dist = ((src_knn[:, :, :, None, :] - src_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn[:, :, :, None, :] - tgt_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_SC_measure = (cross_dist < self.d_thre).float()
        local_SC2_measure = torch.matmul(local_hard_SC_measure[:, :, :1, :], local_hard_SC_measure)

        #################################
        # perform second stage consensus set sampling
        #################################
        sorted_score = torch.argsort(local_SC2_measure, dim=3, descending=True)
        knn_idx_fine = sorted_score[:, :, :, 0: k2]

        #################################
        # construct the soft SC2 matrix of the consensus set
        #################################
        num = knn_idx_fine.shape[1]
        knn_idx_fine = knn_idx_fine.contiguous().view([bs, num, -1])[:, :, :, None]
        knn_idx_fine = knn_idx_fine.expand(-1, -1, -1, 3)
        src_knn_fine = src_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])  # [bs, num_seeds, k, 3]
        tgt_knn_fine = tgt_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])

        src_dist = ((src_knn_fine[:, :, :, None, :] - src_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn_fine[:, :, :, None, :] - tgt_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_measure = (cross_dist < self.d_thre * 2).float()
        local_SC2_measure = torch.matmul(local_hard_measure, local_hard_measure) / k2
        local_SC_measure = torch.clamp(1 - cross_dist ** 2 / self.d_thre ** 2, min=0)
        # local_SC2_measure = local_SC_measure * local_SC2_measure
        local_SC2_measure = local_SC_measure
        local_SC2_measure = local_SC2_measure.view([-1, k2, k2])

        #################################
        # Power iteration to get the inlier probability
        #################################
        local_SC2_measure[:, torch.arange(local_SC2_measure.shape[1]), torch.arange(local_SC2_measure.shape[1])] = 0
        total_weight = self.cal_leading_eigenvector(local_SC2_measure, method='power')
        total_weight = total_weight.view([bs, -1, k2])
        total_weight = total_weight / (torch.sum(total_weight, dim=-1, keepdim=True) + 1e-6)

        #################################
        # calculate the transformation by weighted least-squares for each subsets in parallel
        #################################
        total_weight = total_weight.view([-1, k2])
        src_knn = src_knn_fine
        tgt_knn = tgt_knn_fine
        src_knn, tgt_knn = src_knn.view([-1, k2, 3]), tgt_knn.view([-1, k2, 3])

        #################################
        # compute the rigid transformation for each seed by the weighted SVD
        #################################
        seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, total_weight)
        seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])

        #################################
        # warp the source keypoints with every seed-wise transformation
        #################################
        pred_position = torch.einsum('bsnm,bmk->bsnk', seedwise_trans[:, :, :3, :3],
                                     src_keypts.permute(0, 2, 1)) + seedwise_trans[:, :, :3, 3:4]  # [bs, num_seeds, 3, num_corr]
        #################################
        # calculate the inlier number for each hypothesis, and find the best transformation for each point cloud pair
        #################################
        pred_position = pred_position.permute(0, 1, 3, 2)  # [bs, num_seeds, num_corr, 3]
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        seedwise_fitness = torch.sum((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]
        batch_best_guess = seedwise_fitness.argmax(dim=1)
        best_guess_ratio = seedwise_fitness[0, batch_best_guess]
        final_trans = seedwise_trans.gather(dim=1,
                                            index=batch_best_guess[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        final_trans_list = seedwise_trans
        return final_trans_list, final_trans

    def cal_leading_eigenvector(self, M, method='power'):
        """
        Calculate the leading eigenvector using the power iteration algorithm or a symmetric eigendecomposition
        Input:
            - M:      [bs, num_corr, num_corr] the compatibility matrix
            - method: select different method for calculating the leading eigenvector.
        Output:
            - solution: [bs, num_corr] leading eigenvector
        """
        if method == 'power':
            # power iteration algorithm
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for i in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                if torch.allclose(leading_eig, leading_eig_last):
                    break
                leading_eig_last = leading_eig
            leading_eig = leading_eig.squeeze(-1)
            return leading_eig
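        elif method == 'eig':  # causes NaN during back-prop
            # torch.symeig is deprecated in favour of torch.linalg.eigh (eigenvalues in ascending order)
            e, v = torch.linalg.eigh(M)
            leading_eig = v[:, :, -1]
            return leading_eig
        else:
            raise ValueError(f"Unknown method: {method}")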

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        """
        Calculate the confidence of the spectral matching solution based on spectral analysis.
        Input:
            - M:          [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig [bs, num_corr]           the leading eigenvector of matrix M
        Output:
            - confidence
        """
        if method == 'eig_value':
            # max eigenvalue as the confidence (Rayleigh quotient)
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            confidence = max_eig_value.squeeze(-1)
            return confidence
        elif method == 'eig_value_ratio':
            # max eigenvalue / second max eigenvalue as the confidence
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            # compute the second largest eigen-value
            B = M - max_eig_value * leading_eig[:, :, None] @ leading_eig[:, None, :]
            solution = torch.ones_like(B[:, :, 0:1])
            for i in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            solution = solution.squeeze(-1)
            second_eig = solution
            second_eig_value = (second_eig[:, None, :] @ B @ second_eig[:, :, None]) / (
                    second_eig[:, None, :] @ second_eig[:, :, None])
            confidence = max_eig_value / second_eig_value
            return confidence
        elif method == 'xMx':
            # max xMx as the confidence (x is the binary solution)
            # rank = torch.argsort(leading_eig, dim=1, descending=True)[:, 0:int(M.shape[1]*self.ratio)]
            # binary_sol = torch.zeros_like(leading_eig)
            # binary_sol[0, rank[0]] = 1
            confidence = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
            confidence = confidence.squeeze(-1) / M.shape[1]
            return confidence

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, it_num, weights=None):
        """
        Perform post refinement using the initial transformation matrix, only adopted during testing.
        Input
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - weights:       [bs, num_corr]
        Output:
            - final_trans:   [bs, 4, 4]
        """
        assert initial_trans.shape[0] == 1
        inlier_threshold = 1.2

        # inlier_threshold_list = [self.inlier_threshold] * it_num

        if self.inlier_threshold == 0.10:  # for 3DMatch
            inlier_threshold_list = [0.10] * it_num
        else:  # for KITTI
            inlier_threshold_list = [1.2] * it_num

        previous_inlier_num = 0
        for inlier_threshold in inlier_threshold_list:
            warped_src_keypts = transform(src_keypts, initial_trans)

            L2_dis = torch.norm(warped_src_keypts - tgt_keypts, dim=-1)
            pred_inlier = (L2_dis < inlier_threshold)[0]  # assume bs = 1
            inlier_num = torch.sum(pred_inlier)
            if abs(int(inlier_num - previous_inlier_num)) < 1:
                break
            else:
                previous_inlier_num = inlier_num
            initial_trans = rigid_transform_3d(
                A=src_keypts[:, pred_inlier, :],
                B=tgt_keypts[:, pred_inlier, :],
                ## https://link.springer.com/article/10.1007/s10589-014-9643-2
                # weights=None,
                weights=1 / (1 + (L2_dis / inlier_threshold) ** 2)[:, pred_inlier],
                # weights=((1-L2_dis/inlier_threshold)**2)[:, pred_inlier]
            )
        return initial_trans

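    def match_pair(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Match source and target keypoints by nearest neighbour in feature space.
        Input:
            - src_keypts:   [bs, N, 3]   source point cloud
            - tgt_keypts:   [bs, M, 3]   target point cloud
            - src_features: [bs, N, C]   features of the source point cloud
            - tgt_features: [bs, M, C]   features of the target point cloud
        Output:
            - src_keypts_corr: [bs, N', 3]  source points of the one-to-one correspondences
            - tgt_keypts_corr: [bs, N', 3]  target points of the one-to-one correspondences
            (N' = N, or self.num_node when sampling)
        """
        N_src = src_features.shape[1]
        N_tgt = tgt_features.shape[1]
        # use all points, or randomly sample self.num_node points (with replacement) from each cloud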
        if self.num_node == 'all':
            src_sel_ind = np.arange(N_src)
            tgt_sel_ind = np.arange(N_tgt)
        else:
            src_sel_ind = np.random.choice(N_src, self.num_node)
            tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
        src_desc = src_features[:, src_sel_ind, :]
        tgt_desc = tgt_features[:, tgt_sel_ind, :]
        src_keypts = src_keypts[:, src_sel_ind, :]
        tgt_keypts = tgt_keypts[:, tgt_sel_ind, :]

        # match points in feature space; assumes L2-normalized descriptors, so ||a - b||^2 = 2 - 2 * a.b
        distance = torch.sqrt(2 - 2 * (src_desc[0] @ tgt_desc[0].T) + 1e-6)
        # distance = torch.abs(src_desc[0] @ tgt_desc[0].T)  # the inner product measures the similarity of two features
        distance = distance.unsqueeze(0)
        source_idx = torch.argmin(distance[0], dim=1)  # index of the most similar target point for each source point
        corr = torch.cat([torch.arange(source_idx.shape[0], device=source_idx.device)[:, None], source_idx[:, None]],
                         dim=-1)
        print("src-tgt point correspondences:", corr)
        # generate correspondences
        src_keypts_corr = src_keypts[:, corr[:, 0]]
        tgt_keypts_corr = tgt_keypts[:, corr[:, 1]]

        return src_keypts_corr, tgt_keypts_corr

    def SC2_PCR(self, src_keypts, tgt_keypts):
        """
        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
        Output:
            - final_trans_list:         [num_seeds, 4, 4], seed-wise transformation matrices (batch 0).
            - final_trans_refine_list:  list of [1, 4, 4], each seed-wise transformation after post refinement.
            - final_trans:              [bs, 4, 4], the seed-wise transformation with the most inliers.
        """
        bs, num_corr = src_keypts.shape[0], tgt_keypts.shape[1]

        #################################
        # downsample points
        #################################
        if num_corr > self.max_points:
            src_keypts = src_keypts[:, :self.max_points, :]
            tgt_keypts = tgt_keypts[:, :self.max_points, :]
            num_corr = self.max_points

        #################################
        # compute cross dist
        #################################
        src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
        target_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
        cross_dist = torch.abs(src_dist - target_dist)

        #################################
        # compute first order measure
        #################################
        SC_dist_thre = self.d_thre
        SC_measure = torch.clamp(1.0 - cross_dist ** 2 / SC_dist_thre ** 2, min=0)
        hard_SC_measure = (cross_dist < SC_dist_thre).float()

        #################################
        # select reliable seed correspondences
        #################################
        confidence = self.cal_leading_eigenvector(SC_measure, method='power')
        seeds = self.pick_seeds(src_dist, confidence, R=self.nms_radius, max_num=int(num_corr * self.ratio))

        #################################
        # compute second order measure
        #################################
        SC2_dist_thre = self.d_thre / 2
        hard_SC_measure_tight = (cross_dist < SC2_dist_thre).float()
        seed_hard_SC_measure = hard_SC_measure.gather(dim=1,
                                                      index=seeds[:, :, None].expand(-1, -1, num_corr))
        seed_hard_SC_measure_tight = hard_SC_measure_tight.gather(dim=1,
                                                                  index=seeds[:, :, None].expand(-1, -1, num_corr))
        SC2_measure = torch.matmul(seed_hard_SC_measure_tight, hard_SC_measure_tight) * seed_hard_SC_measure

        #################################
        # compute the seed-wise transformations and select the best one
        #################################
        final_trans_list, final_trans = self.cal_seed_trans(seeds, SC2_measure, src_keypts, tgt_keypts)
        final_trans_list = final_trans_list[0]
        #################################
        # refine the result by recomputing the transformation over the whole set
        #################################
        final_trans_refine_list = []
        for i in range(len(final_trans_list)):
            final_trans = final_trans_list[i].reshape(1, 4, 4)
            refine_trans = self.post_refinement(final_trans, src_keypts, tgt_keypts, 20)
            final_trans_refine_list.append(refine_trans)
        return final_trans_list, final_trans_refine_list, final_trans

    def estimator(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
            - src_features: [bs, num_corr, C]
            - tgt_features: [bs, num_corr, C]
        Output (the three values returned by SC2_PCR):
            - final_trans_list:        seed-wise candidate transformations (batch index 0)
            - final_trans_refine_list: list of those candidates after post_refinement (20 iterations each)
            - final_trans:             [bs, 4, 4], the best seed-wise transformation
        """
        #################################
        # generate coarse correspondences
        #################################
        src_keypts_corr, tgt_keypts_corr = self.match_pair(src_keypts, tgt_keypts, src_features, tgt_features)

        #################################
        # use the proposed SC2-PCR to estimate the rigid transformation
        #################################
        final_trans_list, final_trans_refine_list, final_trans = self.SC2_PCR(src_keypts_corr, tgt_keypts_corr)
        print("final_trans_list:", final_trans_list)
        print("final_trans_refine_list:", final_trans_refine_list)
        # frag1_warp = transform(src_keypts_corr, pred_trans)
        # distance = torch.sum((frag1_warp - tgt_keypts_corr) ** 2, dim=-1) ** 0.5
        # pred_labels = (distance < self.inlier_threshold).float()

        return final_trans_list, final_trans_refine_list, final_trans
        # return pred_trans, pred_labels, src_keypts_corr, tgt_keypts_corr


def set_seed(seed=51):
    """
    Set the random seeds so that results are reproducible.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_3D_bbox(pcld_ary, small=False):
    min_x, max_x = pcld_ary[:, 0].min(), pcld_ary[:, 0].max()
    min_y, max_y = pcld_ary[:, 1].min(), pcld_ary[:, 1].max()
    min_z, max_z = pcld_ary[:, 2].min(), pcld_ary[:, 2].max()
    bbox = np.array([
        [min_x, min_y, min_z],
        [min_x, min_y, max_z],
        [min_x, max_y, min_z],
        [min_x, max_y, max_z],
        [max_x, min_y, min_z],
        [max_x, min_y, max_z],
        [max_x, max_y, min_z],
        [max_x, max_y, max_z],
    ])
    if small:
        center = np.mean(bbox, 0)
        bbox = (bbox - center[None, :]) * 2.0 / 3.0 + center[None, :]
    return bbox


def bernstein(vala, valb):
    """thanks to special sauce https://stackoverflow.com/a/34006336/10059727"""
    h = 1009
    h = h * 9176 + vala
    h = h * 9176 + valb
    return h


def average_rotations(rotations):
    """thanks to jonathan https://stackoverflow.com/a/27410865/10059727"""
    Q = np.zeros((4, len(rotations)))

    for i, rot in enumerate(rotations):
        quat = tf.quaternion_from_matrix(rot)
        Q[:, i] = quat

    _, v = np.linalg.eigh(Q @ Q.T)
    quat_avg = v[:, -1]

    return tf.quaternion_matrix(quat_avg)


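def cluster_poses(poses, dist_max=0.5, rot_max_deg=10, pdist_rot=None):
    """
    Cluster pose hypotheses given as (T_m2s, _, score) tuples by translation and by rotation,
    and return the averaged pose of every cluster whose geometric-mean score is at least
    half that of the highest-scoring cluster.
    """
    rots = np.array([T_m2s[:3, :3] for T_m2s, _, _ in poses])
    locs = np.array([T_m2s[:3, 3] for T_m2s, _, _ in poses])
    scores = np.array([score for _, _, score in poses])

    method = "centroid"

    # 1) cluster by location
    dist_dists = pdist(locs)
    dist_dendro = linkage(dist_dists, method)
    dist_clusters = fcluster(dist_dendro, dist_max, criterion="distance")

    # 2) cluster by rotation
    # XXX optimize: this could be split into several smaller clustering problems,
    # since a cluster spanning distant poses makes no sense
    if pdist_rot is None:
        # Assumed fallback when no metric is passed: condensed pairwise geodesic
        # angles in degrees (same layout as scipy's pdist), matching rot_max_deg
        def pdist_rot(R):
            n = len(R)
            return np.array([np.degrees(np.arccos(np.clip((np.trace(R[i].T @ R[j]) - 1.0) / 2.0, -1.0, 1.0)))
                             for i in range(n) for j in range(i + 1, n)])
    rot_dists = pdist_rot(rots)
    rot_dendro = linkage(rot_dists, method)
    rot_clusters = fcluster(rot_dendro, rot_max_deg, criterion="distance")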

    # Combine the two clusterings, by creating new clusters
    # if two poses are in same cluster in loc and rot, they will be in new
    # common cluster (hash of both cluster ids)
    pose_clusters = bernstein(dist_clusters, rot_clusters)

    # remap the ludicrous hash values to range 0..num
    _, pose_clusters = np.unique(pose_clusters, return_inverse=True)

    cluster_scores = np.zeros(np.max(pose_clusters) + 1)
    for pose_score, pose_cluster in zip(scores, pose_clusters):
        cluster_scores[pose_cluster] += pose_score

    best_cluster_idx = np.argmax(cluster_scores)
    print(
        "Best cluster",
        best_cluster_idx,
        cluster_scores[best_cluster_idx],
        np.count_nonzero(pose_clusters == best_cluster_idx),
    )

    # plt.hist(cluster_scores, histtype="stepfilled", bins=100)
    # plt.title("Cluster Scores Histogram")
    # plt.show()

    out_ts = defaultdict(list)
    out_Rs = defaultdict(list)
    for pose_idx, cluster_idx in enumerate(pose_clusters):
        out_ts[cluster_idx].append(locs[pose_idx])
        out_Rs[cluster_idx].append(rots[pose_idx])

    sorted_clusters = np.argsort(cluster_scores)[::-1]

    for idx_cluster, top_cluster in enumerate(sorted_clusters[:10]):
        print(f"cluster {idx_cluster} contains {len(out_ts[top_cluster])} poses")

    geo = lambda x, y: np.sqrt(x * y)

    best_cluster_idx = np.argmax(cluster_scores)
    best_score = cluster_scores[best_cluster_idx]
    best_geo_score = geo(best_score, len(out_ts[best_cluster_idx]))
    best_rel_thresh = 0.5

    out_poses = []
    for cluster_idx in sorted_clusters:
        geo_score = geo(len(out_ts[cluster_idx]), cluster_scores[cluster_idx])
        if geo_score < best_rel_thresh * best_geo_score:
            continue

        print("cluster idx", cluster_idx, cluster_scores[cluster_idx], "geoscore", geo_score)
        ts = out_ts[cluster_idx]
        Rs = out_Rs[cluster_idx]

        avg_t = np.mean(ts, axis=0)
        avg_R = average_rotations(Rs)
        out_T = np.eye(4)
        out_T[:3, :3] = avg_R[:3, :3]
        out_T[:3, 3] = avg_t
        out_poses.append((out_T, 0, geo_score))

    print("Returning", len(out_poses), "clustered and averaged poses")
    return out_poses


def paxini_sc2_plus(model_down_pcd, scene_down_pcd, inlier_threshold, num_node,
                    use_mutual, d_thre, num_iterations, ratio, nms_radius, max_points, k1, k2, visual=True):
    t0 = time.time()
    model_down_pcd = model_down_pcd.farthest_point_down_sample(1000)  # 3-500 points; the model can keep more points, the scene fewer
    # model_down_pcd = model_down_pcd.voxel_down_sample(0.004)  # 3-500 points; the model can keep more points, the scene fewer
    scene_down_pcd = scene_down_pcd.voxel_down_sample(0.0025)
    # print("number of model points:", len(model_down_pcd.points))
    # print("number of scene points:", len(scene_down_pcd.points))
    model_down_array = np.array(model_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    scene_down_array = np.array(scene_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    model_down_tensor = torch.from_numpy(model_down_array).cuda()
    scene_down_tensor = torch.from_numpy(scene_down_array).cuda()
    # print('---------------2.compute pcd local features-----------------------------------------')
    src_features = extract_fpfh_features(model_down_pcd, 0.05)
    tgt_features = extract_fpfh_features(scene_down_pcd, 0.05)
    src_features = src_features.reshape(1, -1, 33)
    tgt_features = tgt_features.reshape(1, -1, 33)
    src_features_tensor = torch.from_numpy(src_features).cuda()
    tgt_features_tensor = torch.from_numpy(tgt_features).cuda()
    # print('---------------3. match descriptor && compute rigid transformation----')
    matcher_plus = Matcher_plus(inlier_threshold=inlier_threshold, num_node=num_node, use_mutual=use_mutual,
                                d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                                max_points=max_points, k1=k1, k2=k2)
    src_keypts_corr, tgt_keypts_corr, trans_list = matcher_plus.estimator(model_down_tensor, scene_down_tensor,
                                                                          src_features_tensor, tgt_features_tensor)
    # print("------------------------4.icp filter -----------------------------------------")
    final_poses_list = trans_list.detach().cpu().numpy()[0]  # pose clustering?
    # print('final_poses_list:', len(final_poses_list))
    # pose cluster:
    # out_poses = cluster_poses(final_poses_list, dist_max=0.5, rot_max_deg=10, pdist_rot=None)
    # print('out_poses:', len(out_poses))
    max_fitness = -1
    # min_rmse = 10
    for i in range(len(final_poses_list)):
        pose_init = final_poses_list[i]
        reg_p2p = o3d.pipelines.registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.001, pose_init,
            o3d.pipelines.registration.TransformationEstimationPointToPlane(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=1))  # iterations (30 by default)
        if reg_p2p.fitness > max_fitness:
            # if reg_p2p.inlier_rmse < min_rmse:
            max_fitness = reg_p2p.fitness
            # min_rmse = reg_p2p.inlier_rmse
            sc2_final_pose = reg_p2p.transformation
            # print(f'{i}:reg_p2p.fitness:', reg_p2p.fitness)
            # print(f'{i}:reg_p2p.inliner_rmse:', reg_p2p.inlier_rmse)
            # print(f'{i}:reg_p2p.transformation:', reg_p2p.transformation)

    reg_p2ps = o3d.pipelines.registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.01, sc2_final_pose,
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=1000))
    sc2_final_pose = reg_p2ps.transformation
    print("sc2 runtime (s):", time.time() - t0)
    if visual:
        model_ts = np.dot(model_down_array, sc2_final_pose[:3, :3].T) + sc2_final_pose[:3, 3]
        ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
        ts_pcd.paint_uniform_color([1, 0, 0])
        scene_down_pcd.paint_uniform_color([0, 0, 1])
        o3d.visualization.draw_geometries([ts_pcd, scene_down_pcd])
    return sc2_final_pose


def paxini_sc2(model_down_pcd0, scene_down_pcd0, inlier_threshold, num_node,
               use_mutual, d_thre, num_iterations, ratio, nms_radius,
               max_points, k1, k2):
    model_down_pcd = model_down_pcd0.voxel_down_sample(0.0025)  # 3-500 points; the model can keep more points, the scene fewer
    scene_down_pcd = scene_down_pcd0.voxel_down_sample(0.0025)
    print("number of model points:", len(model_down_pcd.points))
    print("number of scene points:", len(scene_down_pcd.points))
    model_down_array = np.array(model_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    scene_down_array = np.array(scene_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    model_down_tensor = torch.from_numpy(model_down_array).cuda()
    scene_down_tensor = torch.from_numpy(scene_down_array).cuda()
    print('---------------2.compute pcd local features-----------------------------------------')
    src_features = extract_fpfh_features(model_down_pcd, 0.05)
    tgt_features = extract_fpfh_features(scene_down_pcd, 0.05)
    src_features = src_features.reshape(1, -1, 33)
    tgt_features = tgt_features.reshape(1, -1, 33)
    src_features_tensor = torch.from_numpy(src_features).cuda()
    tgt_features_tensor = torch.from_numpy(tgt_features).cuda()

    print('---------------3. match descriptor && compute rigid transformation----')
    matcher = Matcher(inlier_threshold=inlier_threshold, num_node=num_node, use_mutual=use_mutual,
                      d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                      max_points=max_points, k1=k1, k2=k2)
    final_trans_list, final_trans_refine_list, final_trans = matcher.estimator(model_down_tensor, scene_down_tensor,
                                                                               src_features_tensor,
                                                                               tgt_features_tensor)
    reg_p = o3d.pipelines.registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.02, final_trans[0].cpu().numpy(),
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=100))
    sc2_final_pose = reg_p.transformation
    model_ts = np.dot(model_down_array, sc2_final_pose[:3, :3].T) + sc2_final_pose[:3, 3]
    ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    ts_pcd.paint_uniform_color([1, 0, 0])
    scene_down_pcd0.paint_uniform_color([0, 0, 1])
    o3d.visualization.draw_geometries([ts_pcd, scene_down_pcd0])
    print("------------------------4.icp filter -----------------------------------------")
    final_poses_list = final_trans_list.detach().cpu().numpy()
    print('final_poses_list:', len(final_poses_list))
    max_fitness = -1
    min_rmse = 10
    for i in range(len(final_poses_list)):
        pose_init = final_poses_list[i]
        reg_p2p = o3d.pipelines.registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.05, pose_init,
            o3d.pipelines.registration.TransformationEstimationPointToPlane(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=1000))  # iterations (30 by default)
        if reg_p2p.fitness > max_fitness:
            # if reg_p2p.inlier_rmse < min_rmse:
            max_fitness = reg_p2p.fitness
            min_rmse = reg_p2p.inlier_rmse
            sc2_final_pose = reg_p2p.transformation
            print(f'{i}:reg_p2p.fitness:', reg_p2p.fitness)
            print(f'{i}:reg_p2p.inliner_rmse:', reg_p2p.inlier_rmse)
            # print(f'{i}:reg_p2p.transformation:', reg_p2p.transformation)
            model_ts = np.dot(model_down_array, sc2_final_pose[:3, :3].T) + sc2_final_pose[:3, 3]
            ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
            ts_pcd.paint_uniform_color([1, 0, 0])
            scene_down_pcd0.paint_uniform_color([0, 0, 1])
            o3d.visualization.draw_geometries([ts_pcd, scene_down_pcd0])
    print("--------final_trans_refine_list--------------------------")
    min_rmse = 10
    for i in range(len(final_trans_refine_list)):
        pose_init = final_trans_refine_list[i].cpu().numpy().reshape(4, 4)
        reg_p2p = o3d.pipelines.registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.02, pose_init,
            o3d.pipelines.registration.TransformationEstimationPointToPlane(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=1))  # iterations (30 by default)
        if reg_p2p.fitness > max_fitness:
            # if reg_p2p.inlier_rmse < min_rmse:
            max_fitness = reg_p2p.fitness
            min_rmse = reg_p2p.inlier_rmse
            sc2_final_refine_pose = reg_p2p.transformation
            print(f'{i}:reg_p2p.fitness:', reg_p2p.fitness)
            print(f'{i}:reg_p2p.inliner_rmse:', reg_p2p.inlier_rmse)
            # print(f'{i}:reg_p2p.transformation:', reg_p2p.transformation)
            model_ts = np.dot(model_down_array, sc2_final_refine_pose[:3, :3].T) + sc2_final_refine_pose[:3, 3]
            ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
            ts_pcd.paint_uniform_color([1, 0, 0])
            scene_down_pcd0.paint_uniform_color([0, 0, 1])
            o3d.visualization.draw_geometries([ts_pcd, scene_down_pcd0])
    print('------------------------5.icp refine -----------------------------------------')
    result_sc2_icp_final = o3d.pipelines.registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.02, sc2_final_pose,
        o3d.pipelines.registration.TransformationEstimationPointToPlane(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=5000))
    final_poses = result_sc2_icp_final.transformation
    confidence = result_sc2_icp_final.fitness
    print("confidence:", confidence)
    print('------------------------6.visual---------------------------------------')
    # model_ts = np.dot(model_down_array, final_poses[:3, :3].T) + final_poses[:3, 3]
    # ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    # o3d.visualization.draw_geometries([ts_pcd,scene_down_pcd0])           # too many points may fail to render
    return final_poses


if __name__ == '__main__':
    set_seed()
    t0 = time.time()
    inlier_threshold = 0.001
    num_node = 'all'
    use_mutual = True
    num_iterations = 10
    ratio = 0.2
    nms_radius = 0.1
    max_points = 8000
    k1 = 30  # Finding the k1 nearest neighbors around each seed
    k2 = 20
    voxel_size = 0.001  # down-sampling voxel size (m)
    d_thre = voxel_size * 2

    print('---------------1.load pcd---------------------------------------------')
    model_pcd_path = r'D:\codes_pose\PPF_Project\3D-Registration-with-Maximal-Cliques\Python_implement\test\carton_02.ply'  # soap_01,,coconut_00,markers_03
    pth = open(r'D:\codes_pose\PPF_Project\ppf-original\model-globally-match-locally-python\example_models'
               r'\paxini_data\dep\0134.pickle',
               'rb')  # soap_01--0000pkl, coconut_00--0002.pickle; carton_02--0134.pkl,marker_pose-200pkl
    model_pcd = o3d.io.read_point_cloud(model_pcd_path)
    model_array = np.hstack((np.array(model_pcd.points), np.array(model_pcd.normals))).astype(np.float32)
    # obj_pose  = np.array([[-0.7287485250189922, -0.6846525505303752, -0.013284289024821653, -0.054847823497723394],  #soap_pose
    #                      [-0.662871399507191, 0.7101669367340677, -0.23720124300923914, -0.04771079338428959],
    #                      [0.17183449885869845, -0.16405428071831307, -0.9713697020084459, 0.5589560143247371],
    #                      [0, 0, 0, 1]])
    obj_pose = np.array(
        [[-0.006038161573090156, -0.9989707481405733, -0.04495536635686892, -0.06520650927334047],  # carton_pose
         [-0.9191053187973984, 0.02325571354051051, -0.393325037019072, -0.06357232357506885],
         [0.3939656756154652, 0.038943756202820916, -0.9182997496948725, 0.4651729732704407],
         [0, 0, 0, 1]])
    # coconut_pose = np.array(
    #     [[0.6237558392913077, -0.7806701285279535, -0.038507186006563776, 0.1410553384970578],  # coconut
    #      [-0.7447691171729245, -0.5786779007489092, -0.33234146489762334, -0.11014853879504719],
    #      [0.2371657965547699, 0.2359788922954422, -0.9423727220880057, 0.5867524362173748],
    #      [0, 0, 0, 1]])
    #  = np.array([[-0.1876610060907363, -0.9820621511025337, -0.0183651344911663, -0.0641744048368183],  # marker_pose
    #                      [-0.9813082289108378, 0.18826291187839725, -0.039890298105449536, -0.09830046364764121],
    #                      [0.042632225661904416, 0.010536004125507674, -0.9990352776314657, 0.44157644191646106],
    #                      [0, 0, 0, 1]])
    # src_pcd.transform(obj_pose)
    depth = np.array(pickle.load(pth))
    depth_array = np.array(depth)
    depth = o3d.geometry.Image(depth_array)
    intrinsic = o3d.camera.PinholeCameraIntrinsic(width=depth_array.shape[1], height=depth_array.shape[0],
                                                  fx=616.58, fy=616.778, cx=323.103, cy=238.464)
    tgt_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth, intrinsic, depth_scale=1000,
                                                              depth_trunc=1000.0)  # with depth_scale=1000 the points are in meters
    print('obj_pose:', obj_pose)
    marker_r = np.array(obj_pose[:3, :3])  # object pose relative to the camera
    marker_t = np.array(obj_pose[:3, 3])
    bbox = get_3D_bbox(model_array)
    bbox *= 1.
    transformed_bbox = np.dot(bbox, marker_r.T) + marker_t
    cropped_pcd = tgt_pcd.crop(
        o3d.geometry.AxisAlignedBoundingBox(min_bound=np.min(transformed_bbox, axis=0),
                                            max_bound=np.max(transformed_bbox, axis=0)))
    cropped_pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))

    model_pcd = model_pcd.transform(obj_pose)  # transform the model into the camera frame first
    o3d.visualization.draw_geometries([model_pcd, cropped_pcd])

    model_down_pcd = model_pcd.voxel_down_sample(voxel_size)
    scene_down_pcd = cropped_pcd.voxel_down_sample(voxel_size)
    print("number of model points:", len(model_down_pcd.points))
    print("number of scene points:", len(scene_down_pcd.points))

    final_poses = paxini_sc2(model_down_pcd, scene_down_pcd, inlier_threshold=inlier_threshold, num_node=num_node,
                             use_mutual=use_mutual,
                             d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                             max_points=max_points, k1=k1, k2=k2)
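The SC2_PCR method above ranks correspondences with two matrices. The first-order spatial-compatibility matrix relies on a rigid motion preserving the distance between any two inlier points. The second-order SC² matrix counts how many correspondences are tightly compatible with both members of a pair. The NumPy sketch below reproduces both measures on synthetic data, so the effect can be inspected without a GPU or Open3D. It is an illustration under stated assumptions, not the SC²-PCR reference code. Correspondences are given as two aligned (N, 3) arrays, and SC² is computed for every row rather than only for the selected seeds. `leading_eigenvector` is a plain power iteration standing in for `cal_leading_eigenvector`, whose body is not shown here. All function names in the block are hypothetical.

```python
import numpy as np


def spatial_compatibility(src, tgt, d_thre):
    """Return the first-order (SC) and second-order (SC^2) compatibility matrices
    for N correspondences src[i] <-> tgt[i]; src and tgt are (N, 3) arrays."""
    src_dist = np.linalg.norm(src[:, None, :] - src[None, :, :], axis=-1)
    tgt_dist = np.linalg.norm(tgt[:, None, :] - tgt[None, :, :], axis=-1)
    cross_dist = np.abs(src_dist - tgt_dist)  # ~0 for two inliers: rigid motions preserve distances
    sc = np.clip(1.0 - cross_dist ** 2 / d_thre ** 2, 0.0, None)
    hard = (cross_dist < d_thre).astype(float)
    hard_tight = (cross_dist < d_thre / 2).astype(float)
    # SC^2[i, j] = number of correspondences tightly compatible with both i and j,
    # kept only where i and j are themselves compatible
    sc2 = (hard_tight @ hard_tight) * hard
    return sc, sc2


def leading_eigenvector(m, num_iterations=10):
    """Power iteration; entry i is used as the confidence of correspondence i."""
    v = np.ones(m.shape[0])
    for _ in range(num_iterations):
        v = m @ v
        v /= np.linalg.norm(v) + 1e-12
    return v


def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


# Synthetic data: 30 exact inliers under a rigid motion, 10 random outliers
rng = np.random.default_rng(0)
n_in, n_out, d_thre = 30, 10, 0.05
src = rng.uniform(-0.5, 0.5, size=(n_in + n_out, 3))
R, t = random_rotation(rng), rng.uniform(-0.2, 0.2, size=3)
tgt = src @ R.T + t
tgt[n_in:] = rng.uniform(-0.5, 0.5, size=(n_out, 3))  # replace the last targets with outliers

sc, sc2 = spatial_compatibility(src, tgt, d_thre)
conf = leading_eigenvector(sc)
print("mean confidence   inliers / outliers:", conf[:n_in].mean(), conf[n_in:].mean())
print("mean SC^2 row sum inliers / outliers:", sc2[:n_in].sum(1).mean(), sc2[n_in:].sum(1).mean())
```

With exact inliers, every inlier-inlier entry of SC² is at least the number of inliers (30 here), while outlier rows only pick up accidental agreements. This separation is what makes the SC² rows useful for choosing consistent neighbours around each seed.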

3.ref

https://arxiv.org/pdf/2203.14453.pdf

https://arxiv.org/pdf/2305.10854.pdf

https://github.com/zhangxy0517/3D-Registration-with-Maximal-Cliques (C++)
