PointNet++分割结果可视化程序

本文参考了http://t.csdnimg.cn/6ohlf这篇博客,对此作者非常感谢,那篇博客是单个点云的分割可视化,但输入的是两幅图像:一幅深度图,一幅彩色图,由此得到点云。

本文稍加改动,输入的是一幅点云,这里例子采用的是输入txt,当然也可以用别的格式吧。

另外记得对那篇博客的代码中的一些地方做一下改动,本文对要改动的地方也注释了一下。

说得有点粗糙,上代码了。

import tqdm
import matplotlib
import torch
import os
import warnings
import numpy as np
import open3d as o3d
from torch.utils.data import Dataset
# import pybullet as p
from models.pointnet2_part_seg_msg import get_model as pointnet2
import time

warnings.filterwarnings('ignore')
matplotlib.use("Agg")


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc, centroid, m



class PartNormalDataset(Dataset):
    def __init__(self, point_cloud, npoints=2500, normal_channel=False):
        self.npoints = npoints  # 采样点数
        self.cat = {}
        self.normal_channel = normal_channel  # 是否使用法向信息

        position_data = np.asarray(point_cloud.points)
        normal_data = np.asarray(point_cloud.normals)
        self.raw_pcd = np.hstack([position_data, normal_data]).astype(np.float32)

        # 分类写一下
        self.cat = {'Airplane': '02691156'}
        # 输出的是元组,('Airplane',123.txt)
        #下面self.classe中的数字是重点,数字代表所选择的类别,比如0是飞机,2是帽子等等
        self.classes = {'Airplane': 0}

        data = self.raw_pcd

        if not self.normal_channel:  # 判断是否使用法向信息
            self.point_set = data[:, 0:3]
        else:
            self.point_set = data[:, 0:6]

        self.point_set[:, 0:3], self.centroid, self.m = pc_normalize(self.point_set[:, 0:3])  # 做一个归一化

        choice = np.random.choice(self.point_set.shape[0], self.npoints, replace=True)  # 对一个类别中的数据进行随机采样 返回索引,允许重复采样
        # resample
        self.point_set = self.point_set[choice, :]  # 根据索引采样

    def __getitem__(self, index):

        cat = list(self.cat.keys())[0]
        cls = self.classes[cat]  # 将类名转换为索引
        cls = np.array([cls]).astype(np.int32)

        return self.point_set, cls, self.centroid, self.m  # pointset是点云数据,cls十六个大类别,seg是一个数据中,不同点对应的小类别

    def __len__(self):
        return 1


class Generate_txt_and_3d_img:
    def __init__(self, num_classes, testDataLoader, model, visualize=False):
        self.testDataLoader = testDataLoader
        self.num_classes = num_classes
        self.heat_map = False  # 控制是否输出heatmap
        self.visualize = visualize  # 是否open3d可视化
        self.model = model

        self.generate_predict()
        self.o3d_draw_3d_img()

    def __getitem__(self, index):
        return self.predict_pcd_colored

    def generate_predict(self):

        for _, (points, label, centroid, m) in tqdm.tqdm(enumerate(self.testDataLoader),
                                                         total=len(self.testDataLoader), smoothing=0.9):

            # 点云数据、整个图像的标签、每个点的标签、  没有归一化的点云数据(带标签)torch.Size([1, 7, 2048])
            points = points.transpose(2, 1)
            # print('1',target.shape) # 1 torch.Size([1, 2048])
            xyz_feature_point = points[:, :6, :]

            model = self.model
            # 下面这行注意改一下,用shapenet的话填16
            seg_pred, _ = model(points, self.to_categorical(label, 16))
            seg_pred = seg_pred.cpu().data.numpy()

            if self.heat_map:
                out = np.asarray(np.sum(seg_pred, axis=2))
                seg_pred = ((out - np.min(out) / (np.max(out) - np.min(out))))
            else:
                seg_pred = np.argmax(seg_pred, axis=-1)  # 获得网络的预测结果 b n c

            seg_pred = np.concatenate([np.asarray(xyz_feature_point), seg_pred[:, None, :]],
                                      axis=1).transpose((0, 2, 1)).squeeze(0)

            self.predict_pcd = seg_pred
            self.centroid = centroid
            self.m = m

    def o3d_draw_3d_img(self):

        pcd = self.predict_pcd
        pcd_vector = o3d.geometry.PointCloud()
        # 加载点坐标
        pcd_vector.points = o3d.utility.Vector3dVector(self.m * pcd[:, :3] + self.centroid)
        # colors = np.random.randint(255, size=(2,3))/255
        # colors这地方改了一下,注意设置多一些,避免分割多个部分颜色不够
        colors = np.array([[0.8, 0.8, 0.8], [1, 0, 0], [0.3, 0.3, 0.3],[0, 1, 0], [0, 0, 1], [0.5, 0.5, 1]])
        pcd_vector.colors = o3d.utility.Vector3dVector(colors[list(map(int, pcd[:, 6])), :])

        if self.visualize:
            coord_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
            o3d.visualization.draw_geometries([pcd_vector, coord_mesh])
        self.predict_pcd_colored = pcd_vector

    def to_categorical(self, y, num_classes):
        """ 1-hot encodes a tensor """
        new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
        if (y.is_cuda):
            return new_y.cuda()
        return new_y


def load_models(model_dict={'PonintNet': [pointnet2(num_classes=50, normal_channel=True).eval(),
                                          r'./log/part_seg/pointnet2_part_seg_msg/checkpoints']}):
    model = list(model_dict.values())[0][0]
    checkpoints_dir = list(model_dict.values())[0][1]
    weight_dict = torch.load(os.path.join(checkpoints_dir, 'best_model.pth'))
    model.load_state_dict(weight_dict['model_state_dict'])
    return model


class Open3dVisualizer():

    def __init__(self):

        self.point_cloud = o3d.geometry.PointCloud()
        self.o3d_started = False

        self.vis = o3d.visualization.VisualizerWithKeyCallback()
        self.vis.create_window()

    def __call__(self, points, colors):

        self.update(points, colors)

        return False

    def update(self, points, colors):
        coord_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.15, origin=[0, 0, 0])
        self.point_cloud.points = points
        self.point_cloud.colors = colors
        # self.point_cloud.transform([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,1]])
        # self.vis.clear_geometries()
        # Add geometries if it is the first time
        if not self.o3d_started:
            self.vis.add_geometry(self.point_cloud)
            self.vis.add_geometry(coord_mesh)
            self.o3d_started = True

        else:
            self.vis.update_geometry(self.point_cloud)
            self.vis.update_geometry(coord_mesh)

        self.vis.poll_events()
        self.vis.update_renderer()


if __name__ == '__main__':
    #这地方改一下
    num_classes = 50  # 填写数据集的类别数 如果是s3dis这里就填13   shapenet这里就填50

    # color_image = o3d.io.read_image('image/rgb1.jpg')
    # depth_image = o3d.io.read_image('image/depth1.png')
    txt_path = './myowncloud/666.txt'
    # 通过numpy读取txt点云
    pcd = np.genfromtxt(txt_path, delimiter=" ")

    point_cloud = o3d.geometry.PointCloud()
    # 加载点坐标
    # txt点云前三个数值一般对应x、y、z坐标,可以通过open3d.geometry.PointCloud().points加载
    # 如果有法线或颜色,那么可以分别通过open3d.geometry.PointCloud().normals或open3d.geometry.PointCloud().colors加载
    point_cloud.points = o3d.utility.Vector3dVector(pcd[:, :3])
    #point_cloud.normals = o3d.utility.Vector3dVector(pcd[:, 3:6])

point_cloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.03, max_nn=30))
    print(np.asarray(point_cloud.points))
    print(np.asarray(point_cloud.normals))
    TEST_DATASET = PartNormalDataset(point_cloud, 30000, True)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=0,
                                                 drop_last=True)
    predict_pcd = Generate_txt_and_3d_img(num_classes, testDataLoader, load_models(), visualize=True)

应该对吧,反正我运行出结果了。这个是把飞机的那个模型的点云抠出来一部分,分割不是特别完美,不过显示已经没问题了。

  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值