




2. An overview of 3D point cloud registration methods

2.1. Feature-based registration methods

2.2. Cloud-to-cloud fine registration methods

2.2.1. ICP-derived methods

2.2.2. Probabilistic methods




思路:为了减少标注量,可以先做3D点云配准。这样只需要标注一帧数据,然后通过配准求出两帧之间的相机位姿变换矩阵,再用该矩阵变换原始标注即可得到新一帧的标注:新一帧的标注数据 = 位姿变换矩阵 × 原始标注数据(齐次坐标、列向量形式)。

paper: https://arxiv.org/ftp/arxiv/papers/2302/2302.07184.pdf



点云配准(point cloud registration)的两个主要步骤:

  1. 粗配准(coarse registration):使用稀疏特征点的对应关系建立两个数据集之间的初始几何变换。这个步骤通常使用特征提取和匹配算法来寻找两个点云之间的对应关系,例如针对三维点云的FPFH、SHOT等局部特征描述子(SIFT、SURF、ORB等算法主要用于二维图像,可在有配套图像时辅助匹配)。这个步骤的目的是将两个点云的初始位置和姿态大致对齐,为后续的精配准提供一个良好的初始估计。

  2. 精配准(fine registration):使用更多和更密集的对应关系来进一步优化几何变换,也称为点云对点云(cloud-to-cloud,C2C)配准。这个步骤通常使用一些迭代最近点(Iterative Closest Point,ICP)算法或其变种来寻找两个点云之间的最优变换。这个步骤的目的是尽可能地减小两个点云之间的距离和误差,使它们在几何上完全对齐。

  论文综述的对象即上述两类方法:基于特征的粗配准(feature-based coarse registration)方法和点云对点云(C2C)精配准(fine registration)方法。

2. An overview of 3D point cloud registration methods

2.1. Feature-based registration methods









2.2. Cloud-to-cloud fine registration methods


2.2.1. ICP-derived methods







2.2.2. Probabilistic methods




其中一种广泛使用的概率方法是一致点漂移(Coherent Point Drift,CPD),它可以用于刚性和非刚性点云配准。CPD将点云配准问题视为概率密度估计问题:把其中一个点云看作高斯混合模型(GMM)的质心,把另一个点云看作由该GMM生成的样本数据,然后通过最大似然估计(通常用EM算法求解)求出使GMM质心整体一致地移动并与数据点对齐的变换。CPD使用近似方法(如快速高斯变换)来加速计算,将计算复杂度降低到线性时间。




import open3d as o3d
import numpy as np
import cv2
import time
import torch
import igraph
import  math
from scipy.spatial.transform import  Rotation
def calculate_inv(RT):
    """Return the inverse of a rigid 4x4 homogeneous transform.

    Uses the closed form for rigid motions: the inverse of ``[R | t]`` is
    ``[R^T | -R^T t]``, so no general matrix inversion is performed.

    Args:
        RT: array indexable as ``RT[:3, :3]`` (rotation) and ``RT[:3, 3]``
            (translation).

    Returns:
        A new 4x4 ``numpy`` array holding the inverse transform.
    """
    rot_t = RT[:3, :3].transpose()
    inverse = np.eye(4)
    inverse[:3, :3] = rot_t
    inverse[:3, 3] = -np.matmul(rot_t, RT[:3, 3])
    return inverse

def rot2euler(rot):
    """Convert a rotation matrix to Euler angles.

    Args:
        rot: rotation matrix (or stack of matrices) accepted by
            ``scipy.spatial.transform.Rotation.from_matrix``.

    Returns:
        Euler angles in radians for the ``'XYZ'`` axis sequence, as returned
        by ``Rotation.as_euler``.
    """
    return Rotation.from_matrix(rot).as_euler('XYZ')

def ransac_icp(source, target):
    # Coarse-to-fine registration of `source` onto `target`:
    #   1. compute FPFH features for both clouds,
    #   2. RANSAC feature matching to obtain an initial transformation,
    #   3. ICP (distance threshold 0.02) started from the RANSAC result.
    # The transformed source is shown next to `target` in an Open3D viewer
    # and the ICP transformation is returned.
    # NOTE(review): the FPFH, RANSAC and ICP calls below are missing their
    # continuation lines / closing parentheses in this copy, so the function
    # does not parse as shown. Restore the elided arguments (search params,
    # estimation method, checkers, convergence criteria) from the original
    # source before use.
    source_array = np.array(source.points).astype(np.float32).reshape(1, -1, 3)
    source_fpfh = o3d.pipelines.registration.compute_fpfh_feature(source,
    target_fpfh = o3d.pipelines.registration.compute_fpfh_feature(target,
    result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
        source, target, source_fpfh, target_fpfh, True,
         ], o3d.pipelines.registration.RANSACConvergenceCriteria(100000,  # these two parameters affect the runtime
    trans_init_icp = result_ransac.transformation
    result_ransac_icp = o3d.pipelines.registration.registration_icp(
        source, target, 0.02, trans_init_icp,
    final_poses = result_ransac_icp.transformation
    # Apply the final pose to the raw source points (row-vector form: p @ R^T + t).
    model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
    model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    # Blocking call: opens a viewer window with the aligned source and the target.
    o3d.visualization.draw_geometries([model_ts_pcd, target])
    return final_poses

def opencv_ppf_icp(source, target):
    # Registration of `source` onto `target` using OpenCV's PPF (point pair
    # feature) detector for the initial pose, refined by open3d_icp.
    # Both clouds are read through `.points` and `.normals`, so normals must
    # already be present on the inputs.
    # NOTE(review): the `scene_array = ...astype(` statement is missing its
    # continuation line in this copy, so the function does not parse as shown.
    # NOTE(review): `model_array` is built but no visible statement uses it;
    # a detector training call may have been lost - confirm against the
    # original source.
    source_down = source
    target_down = target
    source_array = np.array(source.points).astype(np.float32).reshape(1, -1, 3)
    # Model as an N x 6 float32 array: xyz followed by the normal.
    model_array = np.hstack((np.array(source_down.points), np.array(source_down.normals))).astype(np.float32)
    detector = cv2.ppf_match_3d_PPF3DDetector(0.025, 0.05)
    scene_array = np.hstack((np.array(target_down.points), np.array(target_down.normals))).astype(
    result_ppfs = detector.match(scene_array, 0.025, 0.05)  # list
    # Only the first returned pose is used; an empty result list raises IndexError here.
    result_ppf = result_ppfs[0]
    print("numVotes:", result_ppf.numVotes)
    print("residual:", result_ppf.residual)
    ts_init = result_ppf.pose
    result_icp = open3d_icp(source, target, ts_init, threshold=0.02, max_iteration=5000)
    final_poses = result_icp.transformation
    # Apply the final pose to the raw source points (row-vector form: p @ R^T + t).
    model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
    model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
    # Blocking call: opens a viewer window with the aligned source and the target.
    o3d.visualization.draw_geometries([model_ts_pcd, target])
    return final_poses

def open3d_icp(source, target, trans_init, threshold, max_iteration):
    """Run Open3D ICP from an initial pose and print timing and result.

    Args:
        source: source point cloud passed to ``registration_icp``.
        target: target point cloud passed to ``registration_icp``.
        trans_init: initial transformation handed to ICP.
        threshold: maximum correspondence distance for ICP.
        max_iteration: iteration cap for the ICP convergence criteria
            (Open3D's own default is 30).

    Returns:
        The registration result object returned by ``registration_icp``.
    """
    # Author's note: roughly 2.5 s; more points gives better results but is
    # slower. The result seems to depend on the number of model points, so
    # both clouds should be downsampled by the same ratio, and a slightly
    # larger scene bounding box works better.
    started = time.time()
    criteria = o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)
    result = o3d.pipelines.registration.registration_icp(
        source, target, threshold, trans_init, criteria)
    elapsed = time.time() - started
    print("*****icp的时间:", elapsed)
    print('*****icp的transformation:', result.transformation)
    print('*****icp的.fitness:', result.fitness)
    return result

def extract_fpfh_features(keypts, downsample):
    """Compute L2-normalised FPFH descriptors for a point cloud.

    Normals are estimated in place on ``keypts`` (radius ``2 * downsample``,
    at most 30 neighbours) and FPFH features are then computed with radius
    ``5 * downsample`` and at most 100 neighbours.

    Args:
        keypts: Open3D point cloud; its normals are (re)estimated.
        downsample: base length used to scale both search radii.

    Returns:
        ``numpy`` array with one row per point, each row divided by its
        L2 norm (plus 1e-6 to avoid division by zero).
    """
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=downsample * 2, max_nn=30)
    keypts.estimate_normals(search_param=normal_search)

    fpfh_search = o3d.geometry.KDTreeSearchParamHybrid(radius=downsample * 5, max_nn=100)
    fpfh = o3d.pipelines.registration.compute_fpfh_feature(keypts, fpfh_search)

    # Open3D stores descriptors column-wise; transpose to one row per point.
    descriptors = np.array(fpfh.data).T
    row_norms = np.linalg.norm(descriptors, axis=1, keepdims=True) + 1e-6
    return descriptors / row_norms

def rigid_transform_3d(A, B, weights=None, weight_threshold=0):
    """Estimate the rigid transform that aligns A to B by weighted SVD.

    Input:
        - A:       [bs, num_corr, 3], source point cloud
        - B:       [bs, num_corr, 3], target point cloud
        - weights: [bs, num_corr]     weight for each correspondence
                   (defaults to all ones); the caller's tensor is not modified
        - weight_threshold: float,    weights below this value are set to 0
    Output:
        - trans:   [bs, 4, 4] homogeneous transforms assembled from the
                   estimated R and t by ``integrate_trans``
    """
    bs = A.shape[0]
    if weights is None:
        weights = torch.ones_like(A[:, :, 0])
    else:
        # Threshold a copy so the in-place masking below never leaks into the
        # tensor owned by the caller.
        weights = weights.clone()
    weights[weights < weight_threshold] = 0

    # Weighted centroids; the epsilon guards against an all-zero weight row.
    weight_sum = torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6
    centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / weight_sum
    centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / weight_sum

    # Centre both point sets.
    Am = A - centroid_A
    Bm = B - centroid_B

    # Weighted covariance H = Am^T diag(w) Bm, one 3x3 matrix per batch entry.
    Weight = torch.diag_embed(weights)
    H = Am.permute(0, 2, 1) @ Weight @ Bm

    # torch.svd returns (U, S, V) with H = U diag(S) V^T; it is run on the CPU
    # and the factors are moved back to the device of the weights.
    U, S, V = torch.svd(H.cpu())
    U, S, V = U.to(weights.device), S.to(weights.device), V.to(weights.device)

    # Force det(R) = +1 so a reflection is never returned.
    delta_UV = torch.det(V @ U.permute(0, 2, 1))
    eye = torch.eye(3)[None, :, :].repeat(bs, 1, 1).to(A.device)
    eye[:, -1, -1] = delta_UV
    R = V @ eye @ U.permute(0, 2, 1)
    t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1)
    return integrate_trans(R, t)

def integrate_trans(R, t):
    """
    Integrate SE3 transformations from R and t, support torch.Tensor and np.ndarray.
    Input
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation matrix
    Output
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
          (same array library as R; torch results live on R's device)
    """
    is_tensor = isinstance(R, torch.Tensor)
    if len(R.shape) == 3:
        bs = R.shape[0]
        if is_tensor:
            trans = torch.eye(4, device=R.device)[None].repeat(bs, 1, 1)
        else:
            # One identity per batch entry; a single [1, 4, 4] identity cannot
            # receive bs > 1 rotations.
            trans = np.tile(np.eye(4), (bs, 1, 1))
        trans[:, :3, :3] = R
        # reshape works for both tensors and ndarrays (ndarray.view does not
        # reshape) and tolerates non-contiguous tensors.
        trans[:, :3, 3:4] = t.reshape(-1, 3, 1)
    else:
        if is_tensor:
            trans = torch.eye(4, device=R.device)
        else:
            trans = np.eye(4)
        trans[:3, :3] = R
        trans[:3, 3:4] = t.reshape(3, 1)
    return trans

def mac_fpfh(src_pcd, tgt_pcd,visual=True):
    # Register `src_pcd` onto `tgt_pcd` with a maximal-clique pipeline:
    #   1. voxel-downsample both clouds and compute FPFH descriptors,
    #   2. match each source keypoint to its nearest target descriptor,
    #   3. build a compatibility graph over the correspondences,
    #   4. enumerate maximal cliques and weight them,
    #   5. keep, per correspondence, the heaviest clique it belongs to,
    #   6. group the kept cliques by size, estimate one pose per clique,
    #   7. score each pose with ICP and refine the best one with a final ICP.
    # Returns the final ICP transformation; when `visual` is true the aligned
    # source and the target are shown in an Open3D viewer (both are recoloured).
    # Requires a CUDA device: correspondences are moved with `.cuda()`.
    # NOTE(review): several statements are missing in this copy (the bodies
    # of `if -1 in filtered_clique_ind:`, `for s in range(...)` and
    # `if len(group) == 1:`, the statements that fill `tensor_list_A/B`, and
    # the tails of the three `registration_icp(` calls), so the function does
    # not parse as shown. Restore them from the original source before use.
    t0 = time.time()
    # src_kpts = src_pcd.farthest_point_down_sample(500)  # 3-500 points; use more points for the model, the scene can use fewer
    # tgt_kpts = tgt_pcd.farthest_point_down_sample(500)
    src_kpts = src_pcd.voxel_down_sample(0.007)  # 3-500 points; use more points for the model, the scene can use fewer
    tgt_kpts = tgt_pcd.voxel_down_sample(0.006)
    print("src 数量:", len(src_kpts.points))
    print("tgt 数量:", len(tgt_kpts.points))
    src_desc = extract_fpfh_features(src_kpts, 0.05)
    tgt_desc = extract_fpfh_features(tgt_kpts, 0.05)
    # o3d.visualization.draw_geometries([o3d.geometry.PointCloud(o3d.utility.Vector3dVector(src_desc))])
    # o3d.visualization.draw_geometries([o3d.geometry.PointCloud(o3d.utility.Vector3dVector(tgt_desc))])
    # Descriptor rows are L2-normalised by extract_fpfh_features, so
    # sqrt(2 - 2 * dot) is their Euclidean distance (1e-6 keeps sqrt's argument positive).
    distance = np.sqrt(2 - 2 * (src_desc @ tgt_desc.T) + 1e-6)
    source_idx = np.argmin(distance, axis=1)  # for each row save the index of minimun

    # One correspondence per source keypoint: (source index, nearest target index).
    corr = np.concatenate([np.arange(source_idx.shape[0])[:, None], source_idx[:, None]],
                          axis=-1)  # n to 1
    src_pts = np.array(src_kpts.points, dtype=np.float32)[corr[:, 0]]
    tgt_pts = np.array(tgt_kpts.points, dtype=np.float32)[corr[:, 1]]
    src_pts = torch.from_numpy(src_pts).cuda()
    tgt_pts = torch.from_numpy(tgt_pts).cuda()
    t1 = time.time()
    print('1.提取特征FPFH+2.特征匹配 time:', t1 - t0)
    # Pairwise point distances within each cloud; two correspondences are
    # compatible when these distances agree between source and target.
    src_dist = ((src_pts[:, None, :] - src_pts[None, :, :]) ** 2).sum(-1) ** 0.5
    tgt_dist = ((tgt_pts[:, None, :] - tgt_pts[None, :, :]) ** 2).sum(-1) ** 0.5
    cross_dis = torch.abs(src_dist - tgt_dist)
    # Compatibility score in [0, 1] with a 0.1 distance scale; the diagonal is
    # cleared and scores below 0.99 are dropped.
    FCG = torch.clamp(1 - cross_dis ** 2 / 0.1 ** 2, min=0)
    FCG = FCG - torch.diag_embed(torch.diag(FCG))
    FCG[FCG < 0.99] = 0
    # SCG = torch.mm(FCG, FCG) * FCG  # 2-D matrix multiplication
    SCG = torch.matmul(FCG, FCG) * FCG  # N-D matrix multiplication, about 0.65 s
    SCG = SCG.cpu().numpy()
    graph = igraph.Graph.Adjacency((SCG > 0).tolist())
    t2 = time.time()
    print('3.建图time:', t2 - t1)
    graph.es['weight'] = SCG[SCG.nonzero()]
    graph.vs['label'] = range(0, corr.shape[0])
    macs = graph.maximal_cliques(min=5)   # author's note: "parameter 3"; candidate for replacement by a C++ extension
    t3 = time.time()
    print('4.搜索团time:', t3 - t2)
    # ta = time.time()
    # clique_weight = mac_filter.filter(macs, SCG)
    # tb = time.time()
    # print("cython:",tb-ta)

    # Weight of a clique = sum of SCG over all of its node pairs.
    clique_weight = np.zeros(len(macs), dtype=float)  # macs: list(tuple(int)), SCG: ndarray; rewriting this loop as a C++ extension would speed it up
    for ind in range(len(macs)):
        mac = list(macs[ind])
        if len(mac) >= 3:
            for i in range(len(mac)):
                for j in range(i + 1, len(mac)):      # time-consuming
                    clique_weight[ind] = clique_weight[ind] + SCG[mac[i], mac[j]]
    tc =time.time()

    # For every correspondence remember the heaviest clique containing it
    # (-1 means "no clique"), and track the largest clique size seen.
    clique_ind_of_node = np.ones(corr.shape[0], dtype=int) * -1
    max_clique_weight = np.zeros(corr.shape[0], dtype=float)
    max_size = 3
    for ind in range(len(macs)):
        mac = list(macs[ind])
        weight = clique_weight[ind]
        if weight > 0:
            for i in range(len(mac)):
                if weight > max_clique_weight[mac[i]]:
                    max_clique_weight[mac[i]] = weight
                    clique_ind_of_node[mac[i]] = ind
                    max_size = len(mac) > max_size and len(mac) or max_size
    td = time.time()
    print("td:", td - tc)
    filtered_clique_ind = list(set(clique_ind_of_node))
    # NOTE(review): the body of this `if` is missing; presumably it removes
    # the -1 placeholder from the list - confirm.
    if -1 in filtered_clique_ind:
    t4 = time.time()
    print('5.后处理time:', t4 - t3)
    # `group[k]` is indexed below as the bucket of cliques of size k + 3.
    group = []
    # NOTE(review): the body of this `for` is missing; presumably it appends
    # one empty bucket per clique size - confirm.
    for s in range(3, max_size + 1):
    for ind in filtered_clique_ind:
        mac = list(macs[ind])
        group[len(mac) - 3].append(ind)

    # Stack same-sized cliques into batches so poses can be solved together.
    tensor_list_A = []
    tensor_list_B = []
    for i in range(len(group)):
        if len(group[i]) > 0:
            batch_A = src_pts[list(macs[group[i][0]])][None]
            batch_B = tgt_pts[list(macs[group[i][0]])][None]
            # NOTE(review): the body of this `if` and the statements that
            # append batch_A / batch_B to the tensor lists are missing here.
            if len(group) == 1:
            for j in range(1, len(group[i])):
                mac = list(macs[group[i][j]])
                src_corr = src_pts[mac][None]
                tgt_corr = tgt_pts[mac][None]
                batch_A = torch.cat((batch_A, src_corr), 0)
                batch_B = torch.cat((batch_B, tgt_corr), 0)
    t5 = time.time()
    print('6.团分组time:', t5 - t4)
    # Score every clique pose with ICP and keep the one with the best fitness.
    # NOTE(review): `mac_trans_init_icp` is only assigned when some candidate
    # has fitness > 0; otherwise the final ICP below fails with a NameError.
    max_score = 0
    for i in range(len(tensor_list_A)):
        trans = rigid_transform_3d(tensor_list_A[i], tensor_list_B[i], None, 0)
        # The inner loop reuses the name `i`; the outer loop is unaffected
        # because `for` rebinds it on each outer iteration.
        for i in range(len(trans)):
            trans_init_icp = trans[i].cpu().numpy()
            reg_p2ps = o3d.pipelines.registration.registration_icp(
                src_kpts, tgt_kpts, 0.005, trans_init_icp,
                #  TransformationEstimationPointToPlane,TransformationEstimationPointToPoint(),TransformationEstimationForGeneralizedICP,TransformationEstimationForColoredICP

            if reg_p2ps.fitness > max_score:
                max_score = reg_p2ps.fitness
                mac_trans_init_icp = reg_p2ps.transformation
                source_array = np.array(src_pcd.points).astype(np.float32).reshape(1, -1, 3)
                final_poses = mac_trans_init_icp
                # model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
                # model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
                # model_ts_pcd.paint_uniform_color([1, 0, 0])
                # tgt_pcd.paint_uniform_color([0, 0, 1])
                # o3d.visualization.draw_geometries([model_ts_pcd, tgt_pcd])
    print("icp filter:",time.time() - t5)
    # print('--------------------------9. ICP visualisation result--------------------------')
    # Final ICP refinement on the downsampled clouds from the best candidate.
    reg_p2p = o3d.pipelines.registration.registration_icp(
        src_kpts, tgt_kpts, 0.005, mac_trans_init_icp,
    print('final icp fitness:', reg_p2p.fitness)
    final_poses = reg_p2p.transformation
    if visual:
        # Show the full-resolution source (red) transformed onto the target (blue).
        source_array = np.array(src_pcd.points).astype(np.float32).reshape(1, -1, 3)
        model_ts = np.dot(source_array, final_poses[:3, :3].T) + final_poses[:3, 3]
        model_ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
        model_ts_pcd.paint_uniform_color([1, 0, 0])
        tgt_pcd.paint_uniform_color([0, 0, 1])
        o3d.visualization.draw_geometries([model_ts_pcd, tgt_pcd])
    return final_poses
import time
import numpy as np
import torch
import sys
import random
from ransac import extract_fpfh_features
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial import KDTree
from collections import defaultdict
import trimesh.transformations as tf

import open3d as o3d
import pickle

import torch
import numpy as np

class Matcher_plus():
    # Store the matcher's hyper-parameters on the instance.
    # NOTE(review): the parameter list after `self,` is missing in this copy,
    # so the signature does not parse as shown. Judging by the assignments
    # below, the parameters are: inlier_threshold, num_node, use_mutual,
    # d_thre, num_iterations, ratio, max_points, nms_radius, k1, k2,
    # FS_TCD_thre, relax_match_num, NS_by_IC. Their order and default values
    # are not visible here - restore them from the original source.
    def __init__(self,
        self.inlier_threshold = inlier_threshold  # distance below which a warped point counts as an inlier
        self.num_node = num_node  # 'all' or the number of keypoints to sample (see match_pair)
        self.use_mutual = use_mutual
        self.d_thre = d_thre  # distance threshold of the spatial-compatibility measures
        self.num_iterations = num_iterations  # maximum iteration of power iteration algorithm
        self.ratio = ratio  # the maximum ratio of seeds.
        self.max_points = max_points  # cap on the number of correspondences (see SC2_PCR)
        self.nms_radius = nms_radius  # radius used for non-maximum suppression of seeds
        self.k1 = k1  # size of the first-stage consensus set (see cal_seed_trans)
        self.k2 = k2  # size of the second-stage consensus set (see cal_seed_trans)
        self.FS_TCD_thre = FS_TCD_thre  # compatibility threshold used in select_best_trans
        self.relax_match_num = relax_match_num  # number of relaxed matches kept per source point
        self.NS_by_IC = NS_by_IC  # number of seed transforms kept by inlier count

    def pick_seeds(self, dists, scores, R, max_num):
        """
        Select seeding points using Non Maximum Suppression. (here we only support bs=1)
        Input:
            - dists:       [bs, num_corr, num_corr] src keypoints distance matrix
            - scores:      [bs, num_corr]     initial confidence of each correspondence
            - R:           float              radius of nms
            - max_num:     int                maximum number of returned seeds
        Output:
            - picked_seeds: [bs, num_seeds]   the index to the seeding correspondences
        """
        assert scores.shape[0] == 1

        # Parallel non-maximum suppression: entry (i, j) is True when i is not
        # beaten by j, i.e. score_i >= score_j or j lies outside radius R of i.
        not_beaten = (scores.T >= scores) | (dists[0] >= R)

        # A correspondence is a local maximum when no neighbour beats it.
        is_local_max = not_beaten.all(dim=-1).float()

        # Suppressed entries get score 0 and therefore sort behind the maxima.
        score_local_max = scores * is_local_max
        ranking = torch.argsort(score_local_max, dim=1, descending=True)

        return ranking[:, 0: max_num].detach()

    def cal_seed_trans(self, seeds, SC2_measure, src_keypts, tgt_keypts):
        """
        Calculate the transformation for each seeding correspondences.
        Input:
            - seeds:       [bs, num_seeds]  the index to the seeding correspondence
                           (kept for interface compatibility; not read here)
            - SC2_measure: [bs, num_seeds, num_channels]  one row of compatibility
                           scores per seed (the first two dimensions are read as
                           batch and seed; the original note labelled dim 1 num_corr)
            - src_keypts:  [bs, num_corr, 3]
            - tgt_keypts:  [bs, num_corr, 3]
        Output:
            - seedwise_trans_relax: [bs, N, 4, 4]  the N = min(NS_by_IC, num_seeds)
                                    transforms with the highest inlier counts
            - final_trans:          [bs, 4, 4]     transform with the highest inlier count
            - trans_list:           [bs, num_seeds, 4, 4]  transform of every seed
        """
        bs, num_channels = SC2_measure.shape[0], SC2_measure.shape[2]
        k1 = self.k1
        k2 = self.k2

        # Too few candidates for the configured set sizes: fall back to 4.
        if k1 > num_channels:
            k1 = 4
            k2 = 4

        # The first stage consensus set sampling
        # Finding the k1 nearest neighbors around each seed
        sorted_score = torch.argsort(SC2_measure, dim=2, descending=True)
        knn_idx = sorted_score[:, :, 0: k1]
        idx_tmp = knn_idx.contiguous().view([bs, -1])
        idx_tmp = idx_tmp[:, :, None]
        idx_tmp = idx_tmp.expand(-1, -1, 3)

        # construct the local SC2 measure of each consensus subset obtained in the first stage.
        src_knn = src_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])  # [bs, num_seeds, k, 3]
        tgt_knn = tgt_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])
        src_dist = ((src_knn[:, :, :, None, :] - src_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn[:, :, :, None, :] - tgt_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_SC_measure = (cross_dist < self.d_thre).float()
        # Only the first row (the first member of each subset) is scored
        # against the rest of the subset.
        local_SC2_measure = torch.matmul(local_hard_SC_measure[:, :, :1, :], local_hard_SC_measure)

        # perform second stage consensus set sampling
        sorted_score = torch.argsort(local_SC2_measure, dim=3, descending=True)
        knn_idx_fine = sorted_score[:, :, :, 0: k2]

        # construct the soft SC2 matrix of the consensus set
        num = knn_idx_fine.shape[1]
        knn_idx_fine = knn_idx_fine.contiguous().view([bs, num, -1])[:, :, :, None]
        knn_idx_fine = knn_idx_fine.expand(-1, -1, -1, 3)
        src_knn_fine = src_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])  # [bs, num_seeds, k, 3]
        tgt_knn_fine = tgt_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])

        src_dist = ((src_knn_fine[:, :, :, None, :] - src_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn_fine[:, :, :, None, :] - tgt_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        # Only the soft first-order measure is used as the local matrix here.
        local_SC2_measure = torch.clamp(1 - cross_dist ** 2 / self.d_thre ** 2, min=0)
        local_SC2_measure = local_SC2_measure.view([-1, k2, k2])

        # Power iteratation to get the inlier probability
        diag_idx = torch.arange(local_SC2_measure.shape[1])
        local_SC2_measure[:, diag_idx, diag_idx] = 0
        total_weight = self.cal_leading_eigenvector(local_SC2_measure, method='power')
        total_weight = total_weight.view([bs, -1, k2])
        total_weight = total_weight / (torch.sum(total_weight, dim=-1, keepdim=True) + 1e-6)

        # calculate the transformation by weighted least-squares for each subsets in parallel
        total_weight = total_weight.view([-1, k2])
        src_knn, tgt_knn = src_knn_fine.view([-1, k2, 3]), tgt_knn_fine.view([-1, k2, 3])

        # compute the rigid transformation for each seed by the weighted SVD
        seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, total_weight)
        seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])

        # calculate the inlier number for each hypothesis, and find the best transformation for each point cloud pair
        pred_position = torch.einsum('bsnm,bmk->bsnk', seedwise_trans[:, :, :3, :3],
                                     src_keypts.permute(0, 2, 1)) + seedwise_trans[:, :, :3, 3:4]
        pred_position = pred_position.permute(0, 1, 3, 2)  # [bs, num_seeds, num_corr, 3]
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        seedwise_fitness = torch.sum((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]

        # Keep at most NS_by_IC hypotheses, ranked by inlier count.
        relax_num = min(self.NS_by_IC, seedwise_fitness.shape[1])
        _, batch_best_guess_relax_idx = torch.topk(seedwise_fitness, relax_num)

        batch_best_guess = seedwise_fitness.argmax(dim=1)
        final_trans = seedwise_trans.gather(
            dim=1,
            index=batch_best_guess[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        seedwise_trans_relax = seedwise_trans.gather(
            dim=1,
            index=batch_best_guess_relax_idx[:, :, None, None].expand(-1, -1, 4, 4))
        trans_list = seedwise_trans
        return seedwise_trans_relax, final_trans, trans_list

    def cal_leading_eigenvector(self, M, method='power'):
        Calculate the leading eigenvector using power iteration algorithm or torch.symeig
            - M:      [bs, num_corr, num_corr] the compatibility matrix
            - method: select different method for calculating the learding eigenvector.
            - solution: [bs, num_corr] leading eigenvector
        if method == 'power':
            # power iteration algorithm
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for i in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                if torch.allclose(leading_eig, leading_eig_last):
                leading_eig_last = leading_eig
            leading_eig = leading_eig.squeeze(-1)
            return leading_eig
        elif method == 'eig':  # cause NaN during back-prop
            e, v = torch.symeig(M, eigenvectors=True)
            leading_eig = v[:, :, -1]
            return leading_eig

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        Calculate the confidence of the spectral matching solution based on spectral analysis.
            - M:          [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig [bs, num_corr]           the leading eigenvector of matrix M
            - confidence
        if method == 'eig_value':
            # max eigenvalue as the confidence (Rayleigh quotient)
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            confidence = max_eig_value.squeeze(-1)
            return confidence
        elif method == 'eig_value_ratio':
            # max eigenvalue / second max eigenvalue as the confidence
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            # compute the second largest eigen-value
            B = M - max_eig_value * leading_eig[:, :, None] @ leading_eig[:, None, :]
            solution = torch.ones_like(B[:, :, 0:1])
            for i in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            solution = solution.squeeze(-1)
            second_eig = solution
            second_eig_value = (second_eig[:, None, :] @ B @ second_eig[:, :, None]) / (
                    second_eig[:, None, :] @ second_eig[:, :, None])
            confidence = max_eig_value / second_eig_value
            return confidence
        elif method == 'xMx':
            # max xMx as the confidence (x is the binary solution)
            # rank = torch.argsort(leading_eig, dim=1, descending=True)[:, 0:int(M.shape[1]*self.ratio)]
            # binary_sol = torch.zeros_like(leading_eig)
            # binary_sol[0, rank[0]] = 1
            confidence = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
            confidence = confidence.squeeze(-1) / M.shape[1]
            return confidence

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, it_num, weights=None):
        """
        Perform post refinement using the initial transformation matrix, only adopted during testing.
        Input:
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - it_num:        int, maximum number of refinement rounds
            - weights:       [bs, num_corr]  accepted for interface compatibility;
                             the refinement computes its own weights and does not read it
        Output:
            - final_trans:   [bs, 4, 4]
        """
        assert initial_trans.shape[0] == 1

        # inlier_threshold_list = [self.inlier_threshold] * it_num

        if self.inlier_threshold == 0.10:  # for 3DMatch
            inlier_threshold_list = [0.10] * it_num
        else:  # for KITTI
            inlier_threshold_list = [1.2] * it_num

        previous_inlier_num = 0
        for inlier_threshold in inlier_threshold_list:
            warped_src_keypts = transform(src_keypts, initial_trans)

            L2_dis = torch.norm(warped_src_keypts - tgt_keypts, dim=-1)
            pred_inlier = (L2_dis < inlier_threshold)[0]  # assume bs = 1
            inlier_num = torch.sum(pred_inlier)
            # Converged (or no inliers at all): keep the current estimate.
            if abs(int(inlier_num - previous_inlier_num)) < 1:
                break
            previous_inlier_num = inlier_num
            # Re-fit on the current inliers, down-weighting large residuals.
            initial_trans = rigid_transform_3d(
                A=src_keypts[:, pred_inlier, :],
                B=tgt_keypts[:, pred_inlier, :],
                ## https://link.springer.com/article/10.1007/s10589-014-9643-2
                # weights=None,
                weights=1 / (1 + (L2_dis / inlier_threshold) ** 2)[:, pred_inlier],
                # weights=((1-L2_dis/inlier_threshold)**2)[:, pred_inlier]
            )
        return initial_trans

    def match_pair(self, src_keypts, tgt_keypts, src_features, tgt_features):
        Select the best model from the rough models filtered by IC Metric
            - src_keypts:  [bs, N, 3]   source point cloud
            - tgt_keypts   [bs, M, 3]   target point cloud
            - src_features  [bs, N,C]  the features of source point cloud
            - tgt_features [bs, M, C]  the features of target point cloud
            - src_keypts:  [bs, N, 3]   source point cloud
            - relax_match_points  [1, N, K, 3]  for each source point, we find K target points as the potential correspondences
            - relax_distance [bs, N, K]  feature distance for the relaxed matches
            - src_keypts_corr [bs, N_C, 3]  source points of N_C one-to-one correspondences
            - tgt_keypts_corr [bs, N_C, 3]  target points of N_C one-to-one correspondences

        N_src = src_features.shape[1]
        N_tgt = tgt_features.shape[1]
        # use all point or sample points.
        if self.num_node == 'all':
            src_sel_ind = np.arange(N_src)
            tgt_sel_ind = np.arange(N_tgt)
            # src_sel_ind = np.random.choice(N_src, self.num_node)
            if self.num_node < N_tgt:
                tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
                tgt_sel_ind = np.arange(N_tgt)

            if self.num_node < N_src:
                src_sel_ind = np.random.choice(N_src, self.num_node)
                src_sel_ind = np.arange(N_src)
            # tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
        src_desc = src_features[:, src_sel_ind, :]
        tgt_desc = tgt_features[:, tgt_sel_ind, :]
        src_keypts = src_keypts[:, src_sel_ind, :]
        tgt_keypts = tgt_keypts[:, tgt_sel_ind, :]

        # match points in feature space.
        distance = torch.sqrt(2 - 2 * (src_desc[0] @ tgt_desc[0].T) + 1e-6)
        distance = distance.unsqueeze(0)
        source_idx = torch.argmin(distance[0], dim=1)
        corr = torch.cat([torch.arange(source_idx.shape[0])[:, None].cuda(), source_idx[:, None]], dim=-1)

        # relax_num = distance.shape[1] // 50
        # if relax_num < 100:
        # relax_num = distance.shape[2] // 100
        relax_num = self.relax_match_num
        relax_distance, relax_source_idx = torch.topk(distance, k=relax_num, dim=-1, largest=False)

        relax_source_idx = relax_source_idx.view(relax_source_idx.shape[0], -1)[:, :, None].expand(-1, -1, 3)
        relax_match_points = tgt_keypts.gather(dim=1, index=relax_source_idx).view(relax_source_idx.shape[0], -1,
                                                                                   relax_num, 3)
        # generate correspondences
        src_keypts_corr = src_keypts[:, corr[:, 0]]
        tgt_keypts_corr = tgt_keypts[:, corr[:, 1]]

        return src_keypts, relax_match_points, relax_distance, src_keypts_corr, tgt_keypts_corr

    def select_best_trans(self, seed_trans, src_keypts, relax_match_points, relax_distance, src_keypts_corr,
                          tgt_keypts_corr):
        """
        Select the best model from the rough models filtered by IC Metric
        Input:
            - seed_trans:  [bs, N_s^{'}, 4, 4]   the model selected by IC, N_s^{'} is the number of reserverd transformation
            - src_keypts   [bs, N, 3]   the source point cloud
            - relax_match_points  [1, N, K, 3]  for each source point, we find K target points as the potential correspondences
            - relax_distance [bs, N, K]  feature distance for the relaxed matches
            - src_keypts_corr [bs, N_C, 3]  source points of N_C one-to-one correspondences
            - tgt_keypts_corr [bs, N_C, 3]  target points of N_C one-to-one correspondences
        Output:
            - the best transformation selected by FS-TCD, or None when no
              candidate reaches a fitness above 0
        """
        seed_num = seed_trans.shape[1]
        # self.inlier_threshold == 0.10: # for 3DMatch

        best_trans = None
        best_fitness = 0

        for i in range(seed_num):
            # 1. refine the transformation by the one-to-one correspondences
            # NOTE(review): src_keypts_corr / tgt_keypts_corr are overwritten
            # in step 4, so later seeds are refined against the previous
            # seed's verified correspondences. Behaviour kept as-is; confirm
            # whether this is intended.
            initial_trans = seed_trans[:, i, :, :]
            initial_trans = self.post_refinement(initial_trans, src_keypts_corr, tgt_keypts_corr, 1)

            # 2. use the transformation to project the source point cloud to target point cloud, and find the nearest neighbor
            warped_src_keypts = transform(src_keypts, initial_trans)
            cross_dist = torch.norm((warped_src_keypts[:, :, None, :] - relax_match_points), dim=-1)
            warped_neighbors = (cross_dist <= self.inlier_threshold).float()
            # Penalise relaxed matches that land far from the warped point so
            # the feature-space minimum prefers geometrically plausible ones.
            renew_distance = relax_distance + 2 * (cross_dist > self.inlier_threshold * 1.5).float()
            _, mask_min_idx = renew_distance.min(dim=-1)

            # 3. find the correspondences whose alignment error is less than the threshold
            # Index column built on the matches' own device rather than forcing CUDA.
            src_idx = torch.arange(mask_min_idx.shape[1], device=mask_min_idx.device)
            corr = torch.cat([src_idx[:, None], mask_min_idx[0][:, None]], dim=-1)
            verify_mask_row = warped_neighbors.sum(-1) > 0

            # 4. use the spatial consistency to verify the correspondences
            if verify_mask_row.float().sum() > 0:
                verify_mask_row_idx = torch.where(verify_mask_row)
                corr_select = corr[verify_mask_row_idx[1]]
                select_relax_match_points = relax_match_points[:, verify_mask_row_idx[1]]
                src_keypts_corr = src_keypts[:, corr_select[:, 0]]
                tgt_keypts_corr = select_relax_match_points.gather(
                    dim=2,
                    index=corr_select[:, 1][None, :, None, None].expand(-1, -1, -1, 3)).squeeze(dim=2)
                src_dist = torch.norm((src_keypts_corr[:, :, None, :] - src_keypts_corr[:, None, :, :]), dim=-1)
                target_dist = torch.norm((tgt_keypts_corr[:, :, None, :] - tgt_keypts_corr[:, None, :, :]), dim=-1)
                abs_corr_compatibility = torch.abs(src_dist - target_dist)

                SC_thre = self.FS_TCD_thre
                corr_compatibility_2 = (abs_corr_compatibility < SC_thre).float()
                compatibility_num = torch.sum(corr_compatibility_2, -1)
                # Fitness = size of the largest mutually compatible neighbourhood.
                renew_fitness = torch.max(compatibility_num)
            else:
                renew_fitness = 0

            if renew_fitness > best_fitness:
                best_trans = initial_trans
                best_fitness = renew_fitness

        return best_trans

    def SC2_PCR(self, src_keypts, tgt_keypts):
        """
        Run the SC2-PCR seed selection and seed-wise transformation estimation.

        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
        Output (the three values returned by ``self.cal_seed_trans``, passed through unchanged):
            - potential_trans_by_IC: presumably [bs, N_s', 4, 4], the seed-wise candidate
              transformations (``estimator`` treats it as a per-seed stack) -- TODO confirm
            - best_trans_by_IC:      presumably [bs, 4, 4], the candidate selected by the
              inlier-count metric -- TODO confirm
            - trans_list:            list/stack of transformations produced by ``cal_seed_trans``
        """
        num_corr = tgt_keypts.shape[1]

        # downsample points: keep only the first ``max_points`` correspondences
        if num_corr > self.max_points:
            src_keypts = src_keypts[:, :self.max_points, :]
            tgt_keypts = tgt_keypts[:, :self.max_points, :]
            num_corr = self.max_points

        # compute cross dist: difference of pairwise lengths in source and target
        src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
        target_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
        cross_dist = torch.abs(src_dist - target_dist)

        # compute first order measure
        SC_dist_thre = self.d_thre
        SC_measure = torch.clamp(1.0 - cross_dist ** 2 / SC_dist_thre ** 2, min=0)
        hard_SC_measure = (cross_dist < SC_dist_thre).float()

        # select reliable seed correspondences
        confidence = self.cal_leading_eigenvector(SC_measure, method='power')
        seeds = self.pick_seeds(src_dist, confidence, R=self.nms_radius, max_num=int(num_corr * self.ratio))

        # compute second order measure (restricted to the rows of the selected seeds)
        SC2_dist_thre = self.d_thre / 2
        hard_SC_measure_tight = (cross_dist < SC2_dist_thre).float()
        seed_row_index = seeds[:, :, None].expand(-1, -1, num_corr)
        seed_hard_SC_measure = hard_SC_measure.gather(dim=1, index=seed_row_index)
        seed_hard_SC_measure_tight = hard_SC_measure_tight.gather(dim=1, index=seed_row_index)
        SC2_measure = torch.matmul(seed_hard_SC_measure_tight, hard_SC_measure_tight) * seed_hard_SC_measure

        # compute the seed-wise transformations and select the best one
        # NOTE(review): the tail of this call was missing in the file; the last argument
        # is reconstructed as ``tgt_keypts`` to mirror Matcher.SC2_PCR -- confirm.
        potential_trans_by_IC, best_trans_by_IC, trans_list = self.cal_seed_trans(seeds, SC2_measure, src_keypts,
                                                                                  tgt_keypts)

        return potential_trans_by_IC, best_trans_by_IC, trans_list

    def estimator(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Match descriptors and estimate candidate rigid transformations.

        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
            - src_features: [bs, num_corr, C]
            - tgt_features: [bs, num_corr, C]
        Output:
            - src_keypts_corr:  [bs, num_corr, 3], the source points in the matched correspondences
            - tgt_keypts_corr:  [bs, num_corr, 3], the target points in the matched correspondences
            - trans_list:       third value returned by ``self.SC2_PCR``
        """
        # generate coarse correspondences
        # NOTE(review): the tail of this call was missing in the file; the arguments are
        # reconstructed from this method's own parameters -- confirm.
        (src_keypts, relax_match_points, relax_distance,
         src_keypts_corr, tgt_keypts_corr) = self.match_pair(src_keypts, tgt_keypts, src_features, tgt_features)

        # use the proposed SC2-PCR to estimate the rigid transformation
        seedwise_trans, _, trans_list = self.SC2_PCR(src_keypts_corr, tgt_keypts_corr)

        # NOTE: best-hypothesis selection (select_best_trans) and post refinement are
        # intentionally not run here; callers receive the raw hypothesis list instead.
        return src_keypts_corr, tgt_keypts_corr, trans_list

def integrate_trans(R, t):
    """
    Integrate SE3 transformations from R and t, support torch.Tensor and np.ndarray.

    Input:
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation matrix
    Output:
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    """
    is_tensor = isinstance(R, torch.Tensor)
    if len(R.shape) == 3:
        bs = R.shape[0]
        if is_tensor:
            # build the identity with R's dtype/device so no silent down-cast happens
            trans = torch.eye(4, dtype=R.dtype, device=R.device)[None].repeat(bs, 1, 1)
        else:
            # one identity per batch element (a single [1, 4, 4] identity cannot hold bs > 1)
            trans = np.tile(np.eye(4), (bs, 1, 1))
        trans[:, :3, :3] = R
        # ``reshape`` exists on both Tensor and ndarray (ndarray.view is a dtype view)
        trans[:, :3, 3:4] = t.reshape(-1, 3, 1)
    else:
        if is_tensor:
            trans = torch.eye(4, dtype=R.dtype, device=R.device)
        else:
            trans = np.eye(4)
        trans[:3, :3] = R
        trans[:3, 3:4] = t.reshape(3, 1)
    return trans

def transform(pts, trans):
    """
    Applies the SE3 transformations, support torch.Tensor and np.ndarray.
    Equation: trans_pts = R @ pts + t

    Input:
        - pts: [num_pts, 3] or [bs, num_pts, 3], pts to be transformed
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    Output:
        - pts: [num_pts, 3] or [bs, num_pts, 3] transformed pts
    """
    if len(pts.shape) == 3:
        if isinstance(pts, torch.Tensor):
            trans_pts = trans[:, :3, :3] @ pts.permute(0, 2, 1) + trans[:, :3, 3:4]
            return trans_pts.permute(0, 2, 1)
        # ndarray has no ``permute``; use np.transpose for the batched NumPy case
        trans_pts = trans[:, :3, :3] @ np.transpose(pts, (0, 2, 1)) + trans[:, :3, 3:4]
        return np.transpose(trans_pts, (0, 2, 1))
    trans_pts = trans[:3, :3] @ pts.T + trans[:3, 3:4]
    return trans_pts.T

def rigid_transform_3d(A, B, weights=None, weight_threshold=0):
    """
    Weighted least-squares rigid alignment (SVD based) of A onto B.

    Input:
        - A:       [bs, num_corr, 3], source point cloud
        - B:       [bs, num_corr, 3], target point cloud
        - weights: [bs, num_corr]     weight for each correspondence
        - weight_threshold: float,    clips points with weight below threshold
    Output:
        - trans:   [bs, 4, 4] transformation built from the estimated R and t
                   via ``integrate_trans``
    """
    bs = A.shape[0]
    if weights is None:
        weights = torch.ones_like(A[:, :, 0])
    # zero out low weights without modifying the caller's tensor in place
    weights = torch.where(weights < weight_threshold, torch.zeros_like(weights), weights)

    # find (weighted) mean of point cloud
    weight_sum = torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6
    centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / weight_sum
    centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / weight_sum

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    # construct weight covariance matrix; scaling the rows of Bm is equivalent to
    # multiplying by diag(weights) but avoids building a [num_corr, num_corr] matrix
    H = Am.permute(0, 2, 1) @ (weights[:, :, None] * Bm)

    # find rotation; torch.svd returns V (not V^T). The decomposition runs on CPU,
    # then the factors are moved back to the device of ``weights``.
    U, _, V = torch.svd(H.cpu())
    U, V = U.to(weights.device), V.to(weights.device)
    delta_UV = torch.det(V @ U.permute(0, 2, 1))
    # flip the last axis when det(V U^T) = -1 so that R is a proper rotation
    eye = torch.eye(3, dtype=A.dtype, device=A.device)[None, :, :].repeat(bs, 1, 1)
    eye[:, -1, -1] = delta_UV
    R = V @ eye @ U.permute(0, 2, 1)
    t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1)
    return integrate_trans(R, t)

class Matcher():
    """
    SC2-PCR style matcher.

    Builds putative correspondences by nearest neighbour search in descriptor
    space, scores them with first/second order spatial compatibility, and
    estimates one rigid transformation per seed correspondence.
    Several methods assume a batch size of 1 (see the asserts below).
    """

    # NOTE(review): the parameter list of __init__ was missing in the file. The names
    # are reconstructed from the attribute assignments below and from the keyword
    # arguments used by callers; the default values are an assumption -- confirm.
    def __init__(self,
                 inlier_threshold=0.10,
                 num_node='all',
                 use_mutual=True,
                 d_thre=0.10,
                 num_iterations=10,
                 ratio=0.2,
                 nms_radius=0.10,
                 max_points=8000,
                 k1=30,
                 k2=20,
                 ):
        self.inlier_threshold = inlier_threshold
        self.num_node = num_node
        self.use_mutual = use_mutual  # stored only; not read by the methods of this class
        self.d_thre = d_thre
        self.num_iterations = num_iterations  # maximum iteration of power iteration algorithm
        self.ratio = ratio  # the maximum ratio of seeds.
        self.max_points = max_points
        self.nms_radius = nms_radius
        self.k1 = k1
        self.k2 = k2

    def pick_seeds(self, dists, scores, R, max_num):
        """
        Select seeding points using Non Maximum Suppression. (here we only support bs=1)

        Input:
            - dists:       [bs, num_corr, num_corr] src keypoints distance matrix
            - scores:      [bs, num_corr]     initial confidence of each correspondence
            - R:           float              radius of nms
            - max_num:     int                maximum number of returned seeds
        Output:
            - picked_seeds: [bs, num_seeds]   the index to the seeding correspondences
        """
        assert scores.shape[0] == 1

        # parallel Non Maximum Suppression (more efficient)
        score_relation = scores.T >= scores  # [num_corr, num_corr], save the relation of leading_eig
        # entries outside the NMS radius never suppress a candidate
        score_relation = score_relation.bool() | (dists[0] >= R).bool()
        is_local_max = score_relation.min(-1)[0].float()

        score_local_max = scores * is_local_max
        sorted_score = torch.argsort(score_local_max, dim=1, descending=True)

        return_idx = sorted_score[:, 0: max_num].detach()

        return return_idx

    def cal_seed_trans(self, seeds, SC2_measure, src_keypts, tgt_keypts):
        """
        Calculate the transformation for each seeding correspondences.

        Input:
            - seeds:         [bs, num_seeds]              the index to the seeding correspondence
            - SC2_measure:   [bs, num_seeds, num_channels] second order compatibility of each seed
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
        Output:
            - final_trans_list: [bs, num_seeds, 4, 4]  transformation estimated for every seed
            - final_trans:      [bs, 4, 4]             seed transformation with the most inliers
        """
        bs, num_channels = SC2_measure.shape[0], SC2_measure.shape[2]
        k1 = self.k1
        k2 = self.k2

        # fall back to tiny consensus sets when there are fewer candidates than k1
        if k1 > num_channels:
            k1 = 4
            k2 = 4

        # The first stage consensus set sampling
        # Finding the k1 nearest neighbors around each seed
        sorted_score = torch.argsort(SC2_measure, dim=2, descending=True)
        knn_idx = sorted_score[:, :, 0: k1]
        idx_tmp = knn_idx.contiguous().view([bs, -1])
        idx_tmp = idx_tmp[:, :, None]
        idx_tmp = idx_tmp.expand(-1, -1, 3)

        # construct the local SC2 measure of each consensus subset obtained in the first stage.
        src_knn = src_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])  # [bs, num_seeds, k, 3]
        tgt_knn = tgt_keypts.gather(dim=1, index=idx_tmp).view([bs, -1, k1, 3])
        src_dist = ((src_knn[:, :, :, None, :] - src_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn[:, :, :, None, :] - tgt_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        local_hard_SC_measure = (cross_dist < self.d_thre).float()
        local_SC2_measure = torch.matmul(local_hard_SC_measure[:, :, :1, :], local_hard_SC_measure)

        # perform second stage consensus set sampling
        sorted_score = torch.argsort(local_SC2_measure, dim=3, descending=True)
        knn_idx_fine = sorted_score[:, :, :, 0: k2]

        # construct the soft SC matrix of the consensus set
        num = knn_idx_fine.shape[1]
        knn_idx_fine = knn_idx_fine.contiguous().view([bs, num, -1])[:, :, :, None]
        knn_idx_fine = knn_idx_fine.expand(-1, -1, -1, 3)
        src_knn_fine = src_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])  # [bs, num_seeds, k, 3]
        tgt_knn_fine = tgt_knn.gather(dim=2, index=knn_idx_fine).view([bs, -1, k2, 3])

        src_dist = ((src_knn_fine[:, :, :, None, :] - src_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn_fine[:, :, :, None, :] - tgt_knn_fine[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        cross_dist = torch.abs(src_dist - tgt_dist)
        # only the soft first-order measure is used as the local compatibility matrix
        local_SC2_measure = torch.clamp(1 - cross_dist ** 2 / self.d_thre ** 2, min=0)
        local_SC2_measure = local_SC2_measure.view([-1, k2, k2])

        # Power iteration to get the inlier probability (diagonal is zeroed first)
        diag_idx = torch.arange(local_SC2_measure.shape[1])
        local_SC2_measure[:, diag_idx, diag_idx] = 0
        total_weight = self.cal_leading_eigenvector(local_SC2_measure, method='power')
        total_weight = total_weight.view([bs, -1, k2])
        total_weight = total_weight / (torch.sum(total_weight, dim=-1, keepdim=True) + 1e-6)

        # calculate the transformation by weighted least-squares for each subsets in parallel
        total_weight = total_weight.view([-1, k2])
        src_knn, tgt_knn = src_knn_fine.view([-1, k2, 3]), tgt_knn_fine.view([-1, k2, 3])

        # compute the rigid transformation for each seed by the weighted SVD
        seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, total_weight)
        seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])

        # calculate the inlier number for each hypothesis, and find the best transformation for each point cloud pair
        pred_position = torch.einsum('bsnm,bmk->bsnk', seedwise_trans[:, :, :3, :3],
                                     src_keypts.permute(0, 2, 1)) + seedwise_trans[:, :, :3, 3:4]
        pred_position = pred_position.permute(0, 1, 3, 2)  # [bs, num_seeds, num_corr, 3]
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        seedwise_fitness = torch.sum((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]
        batch_best_guess = seedwise_fitness.argmax(dim=1)
        final_trans = seedwise_trans.gather(dim=1,
                                            index=batch_best_guess[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        final_trans_list = seedwise_trans
        return final_trans_list, final_trans

    def cal_leading_eigenvector(self, M, method='power'):
        """
        Calculate the leading eigenvector using power iteration algorithm or an
        eigen-decomposition.

        Input:
            - M:      [bs, num_corr, num_corr] the compatibility matrix
            - method: 'power' or 'eig', select the method for calculating the leading eigenvector.
        Output:
            - solution: [bs, num_corr] leading eigenvector
        Raises:
            - ValueError: if ``method`` is not one of the supported names.
        """
        if method == 'power':
            # power iteration algorithm
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for _ in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                # stop early once the vector no longer changes
                if torch.allclose(leading_eig, leading_eig_last):
                    break
                leading_eig_last = leading_eig
            leading_eig = leading_eig.squeeze(-1)
            return leading_eig
        elif method == 'eig':  # cause NaN during back-prop
            # torch.symeig is gone from recent PyTorch; prefer torch.linalg.eigh when present.
            # Both return eigenvalues in ascending order, so the last column is the leading one.
            eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
            if eigh is not None:
                _, v = eigh(M)
            else:
                _, v = torch.symeig(M, eigenvectors=True)
            leading_eig = v[:, :, -1]
            return leading_eig
        raise ValueError("unknown method for cal_leading_eigenvector: %r" % (method,))

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        """
        Calculate the confidence of the spectral matching solution based on spectral analysis.

        Input:
            - M:          [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig [bs, num_corr]           the leading eigenvector of matrix M
            - method:     'eig_value', 'eig_value_ratio' or 'xMx'
        Output:
            - confidence
        Raises:
            - ValueError: if ``method`` is not one of the supported names.
        """
        if method == 'eig_value':
            # max eigenvalue as the confidence (Rayleigh quotient)
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            confidence = max_eig_value.squeeze(-1)
            return confidence
        elif method == 'eig_value_ratio':
            # max eigenvalue / second max eigenvalue as the confidence
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            # compute the second largest eigen-value by deflating M and iterating again
            B = M - max_eig_value * leading_eig[:, :, None] @ leading_eig[:, None, :]
            solution = torch.ones_like(B[:, :, 0:1])
            for _ in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            second_eig = solution.squeeze(-1)
            second_eig_value = (second_eig[:, None, :] @ B @ second_eig[:, :, None]) / (
                    second_eig[:, None, :] @ second_eig[:, :, None])
            confidence = max_eig_value / second_eig_value
            return confidence
        elif method == 'xMx':
            # x^T M x as the confidence, normalised by the number of correspondences
            confidence = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
            confidence = confidence.squeeze(-1) / M.shape[1]
            return confidence
        raise ValueError("unknown method for cal_confidence: %r" % (method,))

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, it_num, weights=None):
        """
        Perform post refinement using the initial transformation matrix, only adopted during testing.

        Input:
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - it_num:        int, maximum number of refinement iterations
            - weights:       [bs, num_corr] (accepted for interface compatibility; not used)
        Output:
            - final_trans:   [bs, 4, 4]
        """
        assert initial_trans.shape[0] == 1

        if self.inlier_threshold == 0.10:  # for 3DMatch
            inlier_threshold_list = [0.10] * it_num
        else:  # for KITTI
            inlier_threshold_list = [1.2] * it_num

        previous_inlier_num = 0
        for inlier_threshold in inlier_threshold_list:
            warped_src_keypts = transform(src_keypts, initial_trans)

            L2_dis = torch.norm(warped_src_keypts - tgt_keypts, dim=-1)
            pred_inlier = (L2_dis < inlier_threshold)[0]  # assume bs = 1
            inlier_num = torch.sum(pred_inlier)
            # converged: the inlier count did not change since the previous iteration
            if abs(int(inlier_num - previous_inlier_num)) < 1:
                break
            else:
                previous_inlier_num = inlier_num
            initial_trans = rigid_transform_3d(
                A=src_keypts[:, pred_inlier, :],
                B=tgt_keypts[:, pred_inlier, :],
                ## https://link.springer.com/article/10.1007/s10589-014-9643-2
                weights=1 / (1 + (L2_dis / inlier_threshold) ** 2)[:, pred_inlier],
            )
        return initial_trans

    def match_pair(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Build putative correspondences by nearest neighbour search in feature space.

        Input:
            - src_keypts:   [bs, N_src, 3]
            - tgt_keypts:   [bs, N_tgt, 3]
            - src_features: [bs, N_src, C]
            - tgt_features: [bs, N_tgt, C]
        Output:
            - src_keypts_corr: [bs, num_corr, 3] selected source points
            - tgt_keypts_corr: [bs, num_corr, 3] target point matched to each source point
        Only the first batch element of the features is used for matching.
        """
        N_src = src_features.shape[1]
        N_tgt = tgt_features.shape[1]
        # use all point or sample points.
        if self.num_node == 'all':
            src_sel_ind = np.arange(N_src)
            tgt_sel_ind = np.arange(N_tgt)
        else:
            src_sel_ind = np.random.choice(N_src, self.num_node)
            tgt_sel_ind = np.random.choice(N_tgt, self.num_node)
        src_desc = src_features[:, src_sel_ind, :]
        tgt_desc = tgt_features[:, tgt_sel_ind, :]
        src_keypts = src_keypts[:, src_sel_ind, :]
        tgt_keypts = tgt_keypts[:, tgt_sel_ind, :]

        # match points in feature space.
        # sqrt(2 - 2 <a, b>) equals the Euclidean distance only for unit-norm descriptors
        # (presumably the features are L2-normalised upstream -- verify against caller).
        distance = torch.sqrt(2 - 2 * (src_desc[0] @ tgt_desc[0].T) + 1e-6)
        source_idx = torch.argmin(distance, dim=1)  # index of the most similar target per source
        # create the index on the same device as the data instead of hard-coding CUDA
        src_range = torch.arange(source_idx.shape[0], device=source_idx.device)
        corr = torch.cat([src_range[:, None], source_idx[:, None]], dim=-1)
        print("src-tgt的点的相关:", corr)
        # generate correspondences
        src_keypts_corr = src_keypts[:, corr[:, 0]]
        tgt_keypts_corr = tgt_keypts[:, corr[:, 1]]

        return src_keypts_corr, tgt_keypts_corr

    def SC2_PCR(self, src_keypts, tgt_keypts):
        """
        Estimate seed-wise transformations and refine each of them.

        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
        Output:
            - final_trans_list:        [num_seeds, 4, 4] seed-wise transformations of batch element 0
            - final_trans_refine_list: list of [1, 4, 4], each seed transformation after post refinement
            - final_trans:             [bs, 4, 4] seed transformation with the most inliers
        """
        num_corr = tgt_keypts.shape[1]

        # downsample points
        if num_corr > self.max_points:
            src_keypts = src_keypts[:, :self.max_points, :]
            tgt_keypts = tgt_keypts[:, :self.max_points, :]
            num_corr = self.max_points

        # compute cross dist
        src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
        target_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
        cross_dist = torch.abs(src_dist - target_dist)

        # compute first order measure
        SC_dist_thre = self.d_thre
        SC_measure = torch.clamp(1.0 - cross_dist ** 2 / SC_dist_thre ** 2, min=0)
        hard_SC_measure = (cross_dist < SC_dist_thre).float()

        # select reliable seed correspondences
        confidence = self.cal_leading_eigenvector(SC_measure, method='power')
        seeds = self.pick_seeds(src_dist, confidence, R=self.nms_radius, max_num=int(num_corr * self.ratio))

        # compute second order measure
        SC2_dist_thre = self.d_thre / 2
        hard_SC_measure_tight = (cross_dist < SC2_dist_thre).float()
        seed_row_index = seeds[:, :, None].expand(-1, -1, num_corr)
        seed_hard_SC_measure = hard_SC_measure.gather(dim=1, index=seed_row_index)
        seed_hard_SC_measure_tight = hard_SC_measure_tight.gather(dim=1, index=seed_row_index)
        SC2_measure = torch.matmul(seed_hard_SC_measure_tight, hard_SC_measure_tight) * seed_hard_SC_measure

        # compute the seed-wise transformations and select the best one
        final_trans_list, final_trans = self.cal_seed_trans(seeds, SC2_measure, src_keypts, tgt_keypts)
        final_trans_list = final_trans_list[0]

        # refine every seed hypothesis by recomputing the transformation over the whole set.
        # A separate loop variable keeps ``final_trans`` (the best hypothesis) intact.
        # NOTE(review): the statement that filled final_trans_refine_list was missing in
        # the file; appending each refined transformation is the assumed intent -- confirm.
        final_trans_refine_list = []
        for seed_trans in final_trans_list:
            refine_trans = self.post_refinement(seed_trans.reshape(1, 4, 4), src_keypts, tgt_keypts, 20)
            final_trans_refine_list.append(refine_trans)
        return final_trans_list, final_trans_refine_list, final_trans

    def estimator(self, src_keypts, tgt_keypts, src_features, tgt_features):
        """
        Match descriptors and estimate rigid transformation hypotheses.

        Input:
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
            - src_features: [bs, num_corr, C]
            - tgt_features: [bs, num_corr, C]
        Output (as returned by ``SC2_PCR``):
            - final_trans_list
            - final_trans_refine_list
            - final_trans
        """
        # generate coarse correspondences
        src_keypts_corr, tgt_keypts_corr = self.match_pair(src_keypts, tgt_keypts, src_features, tgt_features)

        # use the proposed SC2-PCR to estimate the rigid transformation
        final_trans_list, final_trans_refine_list, final_trans = self.SC2_PCR(src_keypts_corr, tgt_keypts_corr)
        print("final_trans_list:", final_trans_list)
        print("final_trans_refine_list:", final_trans_refine_list)

        return final_trans_list, final_trans_refine_list, final_trans

def set_seed(seed=51):
    """
    Set the random seeds to make results reproducible.

    Seeds the torch (CPU and all CUDA devices), NumPy and Python ``random``
    generators and switches cuDNN to deterministic mode.

    Input:
        - seed: int, the seed value shared by all generators
    """
    # local import: ``random`` is not among the imports at the top of this file
    import random

    torch.manual_seed(seed)  # CPU generator
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_3D_bbox(pcld_ary, small=False):
    """
    Return the 8 corners of the axis-aligned bounding box of a point array.

    Input:
        - pcld_ary: [num_pts, >=3] array, columns 0..2 are read as x, y, z
        - small:    bool, when True the box is shrunk to 2/3 of its size about its center
    Output:
        - bbox: [8, 3] corners, ordered with x varying slowest and z fastest
          (min before max on every axis)
    """
    min_x, max_x = pcld_ary[:, 0].min(), pcld_ary[:, 0].max()
    min_y, max_y = pcld_ary[:, 1].min(), pcld_ary[:, 1].max()
    min_z, max_z = pcld_ary[:, 2].min(), pcld_ary[:, 2].max()
    bbox = np.array([
        [min_x, min_y, min_z],
        [min_x, min_y, max_z],
        [min_x, max_y, min_z],
        [min_x, max_y, max_z],
        [max_x, min_y, min_z],
        [max_x, min_y, max_z],
        [max_x, max_y, min_z],
        [max_x, max_y, max_z],
    ])
    if small:
        center = np.mean(bbox, 0)
        bbox = (bbox - center[None, :]) * 2.0 / 3.0 + center[None, :]
    return bbox

def bernstein(vala, valb):
    """Combine two values into one hash value.

    Starts from 1009 and folds in ``vala`` then ``valb``, multiplying the running
    value by 9176 before each addition.
    Credit: https://stackoverflow.com/a/34006336/10059727
    """
    acc = 1009
    for term in (vala, valb):
        acc = acc * 9176 + term
    return acc

def average_rotations(rotations):
    """Average a sequence of rotations through their quaternions.

    Each rotation is converted with ``tf.quaternion_from_matrix``; the quaternions
    are stacked as the columns of Q and the eigenvector of ``Q @ Q.T`` belonging to
    the largest eigenvalue is converted back with ``tf.quaternion_matrix``.
    Credit: https://stackoverflow.com/a/27410865/10059727
    """
    quats = [tf.quaternion_from_matrix(rot) for rot in rotations]
    # one quaternion per column; the reshape keeps the (4, 0) shape for empty input
    Q = np.asarray(quats, dtype=float).reshape(-1, 4).T

    # eigh returns eigenvalues in ascending order, so the last column is the dominant one
    _, eigenvectors = np.linalg.eigh(Q @ Q.T)
    return tf.quaternion_matrix(eigenvectors[:, -1])

def _pairwise_rotation_angles_deg(rots):
    """Condensed pairwise geodesic angles, in degrees, between [n, 3, 3] rotation matrices."""
    # upper-triangle pairs in row-major order, the same ordering pdist uses
    i, j = np.triu_indices(len(rots), k=1)
    rel = np.einsum('nij,nik->njk', rots[i], rots[j])  # R_i^T @ R_j
    cos_angle = (np.trace(rel, axis1=1, axis2=2) - 1.0) / 2.0
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))


def cluster_poses(poses, dist_max=0.5, rot_max_deg=10, pdist_rot=None):
    """
    Cluster pose hypotheses by location and rotation and average each strong cluster.

    Input:
        - poses:       sequence of (T, _, score) tuples, T being a 4x4 transformation
        - dist_max:    distance threshold passed to fcluster for the translations
        - rot_max_deg: distance threshold passed to fcluster for the rotations
        - pdist_rot:   callable mapping the [n, 3, 3] rotations to a condensed distance
                       vector; when None the pairwise rotation angle in degrees is used
    Output:
        - list of (T_avg, 0, geo_score) tuples, one per retained cluster, ordered by
          descending summed cluster score
    """
    # local imports: these names are not among the imports at the top of this file
    from collections import defaultdict
    from scipy.cluster.hierarchy import fcluster, linkage
    from scipy.spatial.distance import pdist

    if len(poses) == 0:
        return []

    rots = np.array([T_m2s[:3, :3] for T_m2s, _, _ in poses])
    locs = np.array([T_m2s[:3, 3] for T_m2s, _, _ in poses])
    scores = np.array([score for _, _, score in poses])

    if pdist_rot is None:
        pdist_rot = _pairwise_rotation_angles_deg

    method = "centroid"

    if len(poses) > 1:
        # 1) cluster by location
        dist_dists = pdist(locs)
        dist_dendro = linkage(dist_dists, method)
        dist_clusters = fcluster(dist_dendro, dist_max, criterion="distance")

        # 2) cluster by rotations
        # XXX optimize, we can make more smaller cluster problems, since
        # a cluster across distant poses doesn't make sense
        rot_dists = pdist_rot(rots)
        rot_dendro = linkage(rot_dists, method)
        rot_clusters = fcluster(rot_dendro, rot_max_deg, criterion="distance")
    else:
        # linkage cannot run on a single observation; one pose forms one cluster
        dist_clusters = np.zeros(1, dtype=np.int64)
        rot_clusters = np.zeros(1, dtype=np.int64)

    # Combine the two clusterings, by creating new clusters
    # if two poses are in same cluster in loc and rot, they will be in new
    # common cluster (hash of both cluster ids); int64 keeps the hash from wrapping
    pose_clusters = bernstein(dist_clusters.astype(np.int64), rot_clusters.astype(np.int64))

    # remap the ludicrous hash values to range 0..num
    _, pose_clusters = np.unique(pose_clusters, return_inverse=True)

    cluster_scores = np.zeros(np.max(pose_clusters) + 1)
    for pose_score, pose_cluster in zip(scores, pose_clusters):
        cluster_scores[pose_cluster] += pose_score

    best_cluster_idx = np.argmax(cluster_scores)
    # NOTE(review): part of this print call was missing in the file; only the
    # arguments that were still present are printed.
    print(
        "Best cluster",
        np.count_nonzero(pose_clusters == best_cluster_idx),
    )

    # NOTE(review): the body of this loop was missing in the file; collecting each
    # pose's translation and rotation per cluster is what the code below requires.
    out_ts = defaultdict(list)
    out_Rs = defaultdict(list)
    for pose_idx, cluster_idx in enumerate(pose_clusters):
        out_ts[cluster_idx].append(locs[pose_idx])
        out_Rs[cluster_idx].append(rots[pose_idx])

    sorted_clusters = np.argsort(cluster_scores)[::-1]

    for idx_cluster, top_cluster in enumerate(sorted_clusters[:10]):
        print(f"cluster {idx_cluster} contains {len(out_ts[top_cluster])} poses")

    # geometric mean of the summed score and the member count of the best cluster
    best_score = cluster_scores[best_cluster_idx]
    best_geo_score = np.sqrt(best_score * len(out_ts[best_cluster_idx]))
    best_rel_thresh = 0.5

    out_poses = []
    for cluster_idx in sorted_clusters:
        geo_score = np.sqrt(len(out_ts[cluster_idx]) * cluster_scores[cluster_idx])
        if geo_score < best_rel_thresh * best_geo_score:
            # NOTE(review): the statement under this condition was missing in the file;
            # skipping the weak cluster (rather than stopping the loop) is assumed.
            continue

        print("cluster idx", cluster_idx, cluster_scores[cluster_idx], "geoscore", geo_score)
        ts = out_ts[cluster_idx]
        Rs = out_Rs[cluster_idx]

        avg_t = np.mean(ts, axis=0)
        avg_R = average_rotations(Rs)
        out_T = np.eye(4)
        out_T[:3, :3] = avg_R[:3, :3]
        out_T[:3, 3] = avg_t
        out_poses.append((out_T, 0, geo_score))

    print("Returning", len(out_poses), "clustered and averaged poses")
    return out_poses

def paxini_sc2_plus(model_down_pcd, scene_down_pcd, inlier_threshold, num_node,
                    use_mutual, d_thre, num_iterations, ratio, nms_radius, max_points, k1, k2, visual=True):
    """
    Register a model point cloud to a scene with Matcher_plus hypotheses plus ICP.

    Steps: down-sample both clouds, compute FPFH features, let ``Matcher_plus``
    produce a list of candidate transformations, score every candidate with a
    one-iteration ICP, then run a final ICP from the best candidate.

    Input:
        - model_down_pcd, scene_down_pcd: Open3D point clouds (model is the source)
        - inlier_threshold ... k2: forwarded unchanged to ``Matcher_plus``
        - visual: bool, show the aligned model (red) over the scene (blue) when True
    Output:
        - sc2_final_pose: 4x4 transformation returned by the final ICP
    Raises:
        - RuntimeError: if the matcher returns no candidate transformation.
    """
    t0 = time.time()
    # the model keeps a fixed number of points; the scene is voxel down-sampled
    model_down_pcd = model_down_pcd.farthest_point_down_sample(1000)
    scene_down_pcd = scene_down_pcd.voxel_down_sample(0.0025)
    model_down_array = np.array(model_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    scene_down_array = np.array(scene_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    model_down_tensor = torch.from_numpy(model_down_array).cuda()
    scene_down_tensor = torch.from_numpy(scene_down_array).cuda()

    # 2. compute pcd local features (reshaped to 33 values per point)
    src_features = extract_fpfh_features(model_down_pcd, 0.05)
    tgt_features = extract_fpfh_features(scene_down_pcd, 0.05)
    src_features = src_features.reshape(1, -1, 33)
    tgt_features = tgt_features.reshape(1, -1, 33)
    src_features_tensor = torch.from_numpy(src_features).cuda()
    tgt_features_tensor = torch.from_numpy(tgt_features).cuda()

    # 3. match descriptor && compute rigid transformation
    matcher_plus = Matcher_plus(inlier_threshold=inlier_threshold, num_node=num_node, use_mutual=use_mutual,
                                d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                                max_points=max_points, k1=k1, k2=k2)
    src_keypts_corr, tgt_keypts_corr, trans_list = matcher_plus.estimator(model_down_tensor, scene_down_tensor,
                                                                          src_features_tensor, tgt_features_tensor)

    # 4. icp filter: keep the candidate with the highest one-iteration ICP fitness
    final_poses_list = trans_list.detach().cpu().numpy()[0]
    max_fitness = -1
    sc2_final_pose = None
    for pose_init in final_poses_list:
        # the estimation method is passed explicitly so that the convergence criteria
        # land in the ``criteria`` position of registration_icp
        reg_p2p = o3d.pipelines.registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.001, pose_init,
            o3d.pipelines.registration.TransformationEstimationPointToPoint(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=1))
        if reg_p2p.fitness > max_fitness:
            max_fitness = reg_p2p.fitness
            sc2_final_pose = reg_p2p.transformation

    if sc2_final_pose is None:
        raise RuntimeError("paxini_sc2_plus: the matcher returned no candidate transformation")

    # NOTE(review): the remaining arguments of this call were missing in the file;
    # point-to-point estimation with Open3D's default convergence criteria is assumed.
    reg_p2ps = o3d.pipelines.registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.01, sc2_final_pose,
        o3d.pipelines.registration.TransformationEstimationPointToPoint())
    sc2_final_pose = reg_p2ps.transformation
    print("sc2 运行时间:", time.time() - t0)
    if visual:
        # apply the pose to the model points: p' = p @ R^T + t
        model_ts = np.dot(model_down_array, sc2_final_pose[:3, :3].T) + sc2_final_pose[:3, 3]
        ts_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(model_ts.reshape(-1, 3)))
        ts_pcd.paint_uniform_color([1, 0, 0])
        scene_down_pcd.paint_uniform_color([0, 0, 1])
        o3d.visualization.draw_geometries([ts_pcd, scene_down_pcd])
    return sc2_final_pose

def _show_alignment(model_points, pose, scene_pcd):
    """Draw ``model_points`` moved by ``pose`` (red) on top of ``scene_pcd`` (blue).

    ``pose`` is indexed as a 4x4 homogeneous matrix (rotation ``[:3, :3]``,
    translation ``[:3, 3]``).  ``scene_pcd`` is recoloured in place and the
    Open3D viewer window blocks until it is closed.
    """
    moved = np.dot(model_points, pose[:3, :3].T) + pose[:3, 3]
    moved_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(moved.reshape(-1, 3)))
    moved_pcd.paint_uniform_color([1, 0, 0])
    scene_pcd.paint_uniform_color([0, 0, 1])
    o3d.visualization.draw_geometries([moved_pcd, scene_pcd])


def paxini_sc2(model_down_pcd0, scene_down_pcd0, inlier_threshold, num_node,
               use_mutual, d_thre, num_iterations, ratio, nms_radius,
               max_points, k1, k2):
    """Register a model point cloud onto a scene point cloud.

    Pipeline: voxel down-sampling -> FPFH features -> ``Matcher.estimator``
    pose hypotheses -> ICP scoring of every hypothesis -> final ICP refinement
    started from the best-scoring hypothesis.

    Args:
        model_down_pcd0: Open3D point cloud of the model (source).
        scene_down_pcd0: Open3D point cloud of the scene (target).  It is
            recoloured in place by the debug visualisations.
        inlier_threshold, num_node, use_mutual, d_thre, num_iterations, ratio,
        nms_radius, max_points, k1, k2: forwarded unchanged to ``Matcher``.

    Returns:
        The ``transformation`` of the final ICP result (model -> scene).

    Side effects:
        Prints progress to stdout and opens blocking Open3D viewer windows.
        Tensors are moved to the GPU with ``.cuda()``, so CUDA is required.
    """
    registration = o3d.pipelines.registration

    # 3-500 points; the model may keep more points, the scene can have fewer.
    model_down_pcd = model_down_pcd0.voxel_down_sample(0.0025)
    scene_down_pcd = scene_down_pcd0.voxel_down_sample(0.0025)
    print("model数量:", len(model_down_pcd.points))
    print("scene数量:", len(scene_down_pcd.points))
    model_down_array = np.array(model_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    scene_down_array = np.array(scene_down_pcd.points).astype(np.float32).reshape(1, -1, 3)
    model_down_tensor = torch.from_numpy(model_down_array).cuda()
    scene_down_tensor = torch.from_numpy(scene_down_array).cuda()

    print('---------------2.compute pcd local features-----------------------------------------')
    src_features = extract_fpfh_features(model_down_pcd, 0.05).reshape(1, -1, 33)
    tgt_features = extract_fpfh_features(scene_down_pcd, 0.05).reshape(1, -1, 33)
    src_features_tensor = torch.from_numpy(src_features).cuda()
    tgt_features_tensor = torch.from_numpy(tgt_features).cuda()

    print('---------------3. match descriptor && compute rigid transformation----')
    matcher = Matcher(inlier_threshold=inlier_threshold, num_node=num_node, use_mutual=use_mutual,
                      d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                      max_points=max_points, k1=k1, k2=k2)
    # NOTE(review): the tail of this call was truncated in the source; the
    # feature tensors are restored to mirror the other estimator call in this
    # file -- confirm against Matcher.estimator's signature.
    final_trans_list, final_trans_refine_list, final_trans = matcher.estimator(
        model_down_tensor, scene_down_tensor, src_features_tensor, tgt_features_tensor)

    # NOTE(review): the trailing arguments of this ICP call were truncated in
    # the source; Open3D's default estimation method and convergence criteria
    # are used here -- confirm the intended settings.
    reg_p = registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.02, final_trans[0].cpu().numpy())
    sc2_final_pose = reg_p.transformation
    _show_alignment(model_down_array, sc2_final_pose, scene_down_pcd0)

    print("------------------------4.icp filter -----------------------------------------")
    final_poses_list = final_trans_list.detach().cpu().numpy()
    print('final_poses_list:', len(final_poses_list))
    max_fitness = -1
    for i, pose_init in enumerate(final_poses_list):
        # The criteria must go in by keyword: the fifth positional parameter
        # of registration_icp is the estimation method, not the criteria.
        # NOTE(review): an explicit estimation-method line may have been lost
        # here; the Open3D default is used.
        reg_p2p = registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.05, pose_init,
            criteria=registration.ICPConvergenceCriteria(max_iteration=1000))
        if reg_p2p.fitness > max_fitness:
            max_fitness = reg_p2p.fitness
            sc2_final_pose = reg_p2p.transformation
            print(f'{i}:reg_p2p.fitness:', reg_p2p.fitness)
            print(f'{i}:reg_p2p.inliner_rmse:', reg_p2p.inlier_rmse)
            _show_alignment(model_down_array, sc2_final_pose, scene_down_pcd0)

    # max_fitness deliberately carries over from the loop above, so a refined
    # hypothesis is only reported when it beats the best unrefined one.  The
    # refined pose is visualised only; it does not seed the final ICP below.
    for i, refine_trans in enumerate(final_trans_refine_list):
        pose_init = refine_trans.cpu().numpy().reshape(4, 4)
        reg_p2p = registration.registration_icp(
            model_down_pcd, scene_down_pcd, 0.02, pose_init,
            criteria=registration.ICPConvergenceCriteria(max_iteration=1))
        if reg_p2p.fitness > max_fitness:
            max_fitness = reg_p2p.fitness
            sc2_final_refine_pose = reg_p2p.transformation
            print(f'{i}:reg_p2p.fitness:', reg_p2p.fitness)
            print(f'{i}:reg_p2p.inliner_rmse:', reg_p2p.inlier_rmse)
            _show_alignment(model_down_array, sc2_final_refine_pose, scene_down_pcd0)

    print('------------------------5.icp refine -----------------------------------------')
    # NOTE(review): the trailing arguments of this ICP call were truncated in
    # the source as well; Open3D defaults are used -- confirm.
    result_sc2_icp_final = registration.registration_icp(
        model_down_pcd, scene_down_pcd, 0.02, sc2_final_pose)
    final_poses = result_sc2_icp_final.transformation
    confidence = result_sc2_icp_final.fitness
    print("confidence:", confidence)
    return final_poses

if __name__ == '__main__':
    # Demo script: load a model and a pickled depth frame, crop the scene
    # around the known object pose, then register the model with paxini_sc2.
    # Imported locally so this script section is self-contained.
    import pickle

    t0 = time.time()
    inlier_threshold = 0.001
    num_node = 'all'
    use_mutual = True
    num_iterations = 10
    ratio = 0.2
    nms_radius = 0.1
    max_points = 8000
    k1 = 30  # Finding the k1 nearest neighbors around each seed
    k2 = 20
    voxel_size = 0.001  # voxel size used for down-sampling
    d_thre = voxel_size * 2

    print('---------------1.load pcd---------------------------------------------')
    model_pcd_path = r'D:\codes_pose\PPF_Project\3D-Registration-with-Maximal-Cliques\Python_implement\test\carton_02.ply'  # soap_01,,coconut_00,markers_03
    # Depth frames per model: soap_01--0000pkl, coconut_00--0002.pickle;
    # carton_02--0134.pkl, marker_pose-200pkl
    # NOTE(review): this path ends at the example_models directory; the file
    # name component appears to have been lost (the list above suggests
    # 0134.pkl for carton_02) -- confirm and append it.
    depth_pkl_path = r'D:\codes_pose\PPF_Project\ppf-original\model-globally-match-locally-python\example_models'
    model_pcd = o3d.io.read_point_cloud(model_pcd_path)
    model_array = np.hstack((np.array(model_pcd.points), np.array(model_pcd.normals))).astype(np.float32)
    # Alternative object poses for the other models:
    # obj_pose  = np.array([[-0.7287485250189922, -0.6846525505303752, -0.013284289024821653, -0.054847823497723394],  #soap_pose
    #                      [-0.662871399507191, 0.7101669367340677, -0.23720124300923914, -0.04771079338428959],
    #                      [0.17183449885869845, -0.16405428071831307, -0.9713697020084459, 0.5589560143247371],
    #                      [0, 0, 0, 1]])
    obj_pose = np.array(
        [[-0.006038161573090156, -0.9989707481405733, -0.04495536635686892, -0.06520650927334047],  # carton_pose
         [-0.9191053187973984, 0.02325571354051051, -0.393325037019072, -0.06357232357506885],
         [0.3939656756154652, 0.038943756202820916, -0.9182997496948725, 0.4651729732704407],
         [0, 0, 0, 1]])
    # coconut_pose = np.array(
    #     [[0.6237558392913077, -0.7806701285279535, -0.038507186006563776, 0.1410553384970578],  # coconut
    #      [-0.7447691171729245, -0.5786779007489092, -0.33234146489762334, -0.11014853879504719],
    #      [0.2371657965547699, 0.2359788922954422, -0.9423727220880057, 0.5867524362173748],
    #      [0, 0, 0, 1]])
    # marker_pose = np.array([[-0.1876610060907363, -0.9820621511025337, -0.0183651344911663, -0.0641744048368183],  # marker_pose
    #                      [-0.9813082289108378, 0.18826291187839725, -0.039890298105449536, -0.09830046364764121],
    #                      [0.042632225661904416, 0.010536004125507674, -0.9990352776314657, 0.44157644191646106],
    #                      [0, 0, 0, 1]])

    # The mode is its own argument: without the comma, 'rb' was concatenated
    # onto the path and the file was opened in text mode, which pickle cannot
    # read.  The context manager closes the handle once the frame is loaded.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(depth_pkl_path, 'rb') as pth:
        depth_array = np.array(pickle.load(pth))
    depth = o3d.geometry.Image(depth_array)
    intrinsic = o3d.camera.PinholeCameraIntrinsic(width=depth_array.shape[1], height=depth_array.shape[0],
                                                  fx=616.58, fy=616.778, cx=323.103, cy=238.464)
    tgt_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth, intrinsic, depth_scale=1000,
                                                              depth_trunc=1000.0)  # with depth_scale=1000 the units are metres as well
    print('obj_pose:', obj_pose)
    marker_r = np.array(obj_pose[:3, :3])  # pose of the object relative to the camera
    marker_t = np.array(obj_pose[:3, 3])
    bbox = get_3D_bbox(model_array)
    bbox *= 1.  # scale hook for the crop box; 1.0 leaves the values unchanged
    transformed_bbox = np.dot(bbox, marker_r.T) + marker_t
    # Crop the scene to the axis-aligned extent of the posed model bounding box.
    cropped_pcd = tgt_pcd.crop(
        o3d.geometry.AxisAlignedBoundingBox(min_bound=np.min(transformed_bbox, axis=0),
                                            max_bound=np.max(transformed_bbox, axis=0)))
    cropped_pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))

    model_pcd = model_pcd.transform(obj_pose)  # apply the known pose up front
    o3d.visualization.draw_geometries([model_pcd, cropped_pcd])

    model_down_pcd = model_pcd.voxel_down_sample(voxel_size)
    scene_down_pcd = cropped_pcd.voxel_down_sample(voxel_size)
    print("model数量:", len(model_down_pcd.points))
    print("scene数量:", len(scene_down_pcd.points))

    # use_mutual is a required parameter of paxini_sc2; omitting it raised a
    # TypeError before any registration could run.
    final_poses = paxini_sc2(model_down_pcd, scene_down_pcd, inlier_threshold=inlier_threshold, num_node=num_node,
                             use_mutual=use_mutual,
                             d_thre=d_thre, num_iterations=num_iterations, ratio=ratio, nms_radius=nms_radius,
                             max_points=max_points, k1=k1, k2=k2)




Reference C++ implementation of 3D Registration with Maximal Cliques: https://github.com/zhangxy0517/3D-Registration-with-Maximal-Cliques

