本文参考了 http://t.csdnimg.cn/6ohlf 这篇博客,在此向原作者表示感谢。那篇博客实现的是单个点云的分割可视化,但它的输入是两幅图像(一幅深度图和一幅彩色图),点云是由这两幅图像生成的。
本文在其基础上稍作改动,直接输入一个点云文件。这里的例子读取的是 txt 格式,其他点云格式只要能读入为坐标数组,应该也同样适用。
另外,那篇博客的代码中有几处需要根据自己的数据做相应修改,本文在代码中对这些需要改动的地方都加了注释。
说得有点粗糙,上代码了。
import tqdm
import matplotlib
import torch
import os
import warnings
import numpy as np
import open3d as o3d
from torch.utils.data import Dataset
# import pybullet as p
from models.pointnet2_part_seg_msg import get_model as pointnet2
import time
# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')
# Select matplotlib's "Agg" backend (renders to buffers/files, opens no GUI window).
matplotlib.use("Agg")
def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: 2-D array with one point per row (the mean is taken over axis 0
            and the per-point norm over axis 1).

    Returns:
        A tuple ``(normalized, centroid, m)`` where ``normalized`` is the
        centered and scaled copy of ``pc``, ``centroid`` is the per-column
        mean that was subtracted, and ``m`` is the largest distance of any
        centered point from the origin. The original coordinates can be
        recovered as ``normalized * m + centroid``.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    # A degenerate cloud (all points identical) has m == 0; the centered
    # points are already all zero, so skip the division instead of
    # producing NaNs.
    if m > 0:
        pc = pc / m
    return pc, centroid, m
class PartNormalDataset(Dataset):
    """Single-sample dataset wrapping one point cloud for part segmentation.

    The cloud is normalized with ``pc_normalize`` and randomly resampled
    (with replacement) to ``npoints`` points once, at construction time.
    Every ``__getitem__`` call returns that same prepared sample.

    Args:
        point_cloud: Object exposing ``points`` and ``normals`` attributes
            that ``np.asarray`` can convert to 2-D arrays (the script in this
            file passes an ``open3d.geometry.PointCloud``).
        npoints: Number of points drawn by the random resampling.
        normal_channel: When true, each point keeps six features (position
            followed by normal); otherwise only the three position columns.

    Raises:
        ValueError: If the cloud has no points, or if ``normal_channel`` is
            true but the cloud carries no per-point normals.
    """

    def __init__(self, point_cloud, npoints=2500, normal_channel=False):
        self.npoints = npoints  # number of points drawn by resampling
        self.normal_channel = normal_channel  # whether normals are used as features
        position_data = np.asarray(point_cloud.points)
        normal_data = np.asarray(point_cloud.normals)
        if position_data.ndim != 2 or position_data.shape[0] == 0:
            raise ValueError('point_cloud contains no points')

        has_normals = (normal_data.ndim == 2
                       and normal_data.shape[0] == position_data.shape[0])
        if has_normals:
            columns = [position_data, normal_data]
        elif self.normal_channel:
            raise ValueError('normal_channel=True requires one normal per point; '
                             'estimate or load normals first')
        else:
            # Normals are not needed, so a cloud without them is acceptable;
            # raw_pcd then holds the position columns only.
            columns = [position_data]
        self.raw_pcd = np.hstack(columns).astype(np.float32)

        # Category name -> ShapeNet synset id.
        self.cat = {'Airplane': '02691156'}
        # Category name -> integer category index fed to the network. The
        # number selects the object category (e.g. 0 is airplane); change it
        # together with ``self.cat`` when segmenting another category.
        self.classes = {'Airplane': 0}

        # Copy the selected columns so the in-place normalization below does
        # not overwrite the un-normalized data kept in raw_pcd.
        width = 6 if self.normal_channel else 3
        self.point_set = self.raw_pcd[:, :width].copy()
        self.point_set[:, 0:3], self.centroid, self.m = pc_normalize(self.point_set[:, 0:3])
        # Random resampling with replacement, so npoints may exceed the
        # number of points actually present in the cloud.
        choice = np.random.choice(self.point_set.shape[0], self.npoints, replace=True)
        self.point_set = self.point_set[choice, :]

    def __getitem__(self, index):
        """Return ``(point_set, cls, centroid, m)``; ``index`` is ignored.

        ``point_set`` is the normalized, resampled cloud, ``cls`` is a
        one-element int32 array holding the category index, and
        ``centroid`` / ``m`` are the values returned by ``pc_normalize``
        that are needed to undo the normalization.
        """
        cat = next(iter(self.cat))
        cls = np.array([self.classes[cat]]).astype(np.int32)
        return self.point_set, cls, self.centroid, self.m

    def __len__(self):
        """The dataset always holds exactly one sample."""
        return 1
class Generate_txt_and_3d_img:
def __init__(self, num_classes, testDataLoader, model, visualize=False):
self.testDataLoader = testDataLoader
self.num_classes = num_classes
self.heat_map = False # 控制是否输出heatmap
self.visualize = visualize # 是否open3d可视化
self.model = model
self.generate_predict()
self.o3d_draw_3d_img()
def __getitem__(self, index):
return self.predict_pcd_colored
def generate_predict(self):
for _, (points, label, centroid, m) in tqdm.tqdm(enumerate(self.testDataLoader),
total=len(self.testDataLoader), smoothing=0.9):
# 点云数据、整个图像的标签、每个点的标签、 没有归一化的点云数据(带标签)torch.Size([1, 7, 2048])
points = points.transpose(2, 1)
# print('1',target.shape) # 1 torch.Size([1, 2048])
xyz_feature_point = points[:, :6, :]
model = self.model
# 下面这行注意改一下,用shapenet的话填16
seg_pred, _ = model(points, self.to_categorical(label, 16))
seg_pred = seg_pred.cpu().data.numpy()
if self.heat_map:
out = np.asarray(np.sum(seg_pred, axis=2))
seg_pred = ((out - np.min(out) / (np.max(out) - np.min(out))))
else:
seg_pred = np.argmax(seg_pred, axis=-1) # 获得网络的预测结果 b n c
seg_pred = np.concatenate([np.asarray(xyz_feature_point), seg_pred[:, None, :]],
axis=1).transpose((0, 2, 1)).squeeze(0)
self.predict_pcd = seg_pred
self.centroid = centroid
self.m = m
def o3d_draw_3d_img(self):
pcd = self.predict_pcd
pcd_vector = o3d.geometry.PointCloud()
# 加载点坐标
pcd_vector.points = o3d.utility.Vector3dVector(self.m * pcd[:, :3] + self.centroid)
# colors = np.random.randint(255, size=(2,3))/255
# colors这地方改了一下,注意设置多一些,避免分割多个部分颜色不够
colors = np.array([[0.8, 0.8, 0.8], [1, 0, 0], [0.3, 0.3, 0.3],[0, 1, 0], [0, 0, 1], [0.5, 0.5, 1]])
pcd_vector.colors = o3d.utility.Vector3dVector(colors[list(map(int, pcd[:, 6])), :])
if self.visualize:
coord_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
o3d.visualization.draw_geometries([pcd_vector, coord_mesh])
self.predict_pcd_colored = pcd_vector
def to_categorical(self, y, num_classes):
""" 1-hot encodes a tensor """
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
if (y.is_cuda):
return new_y.cuda()
return new_y
def load_models(model_dict=None):
    """Load trained weights into the first model listed in ``model_dict``.

    Args:
        model_dict: Mapping whose first value is a sequence
            ``[model, checkpoints_dir]``. When omitted, a PointNet++ part
            segmentation network (50 part labels, normals enabled) in eval
            mode is created and paired with the default checkpoint folder
            ``./log/part_seg/pointnet2_part_seg_msg/checkpoints``.

    Returns:
        The model after ``best_model.pth`` from ``checkpoints_dir`` has been
        loaded into it.
    """
    if model_dict is None:
        # Built lazily: a default argument would be evaluated once at import
        # time, constructing the network even when the caller supplies its
        # own model, and would be shared between calls.
        model_dict = {'PonintNet': [pointnet2(num_classes=50, normal_channel=True).eval(),
                                    r'./log/part_seg/pointnet2_part_seg_msg/checkpoints']}
    entry = next(iter(model_dict.values()))
    model = entry[0]
    checkpoints_dir = entry[1]
    # map_location='cpu' lets a checkpoint saved on a GPU be read on a
    # machine without CUDA; load_state_dict then copies the values into the
    # model's own parameters.
    weight_dict = torch.load(os.path.join(checkpoints_dir, 'best_model.pth'),
                             map_location='cpu')
    model.load_state_dict(weight_dict['model_state_dict'])
    return model
class Open3dVisualizer():
    """Non-blocking Open3D window that redraws one point cloud on demand.

    The window is created on construction. Each ``update`` (or call of the
    instance) replaces the points and colors of the displayed cloud and
    processes pending window events.
    """

    def __init__(self):
        self.point_cloud = o3d.geometry.PointCloud()
        self.o3d_started = False  # becomes True once the geometries are registered
        self.vis = o3d.visualization.VisualizerWithKeyCallback()
        self.vis.create_window()

    def __call__(self, points, colors):
        """Redraw with the given data; always returns ``False``."""
        self.update(points, colors)
        return False

    def update(self, points, colors):
        """Show ``points`` / ``colors`` in the window.

        Args:
            points: Value assigned to the cloud's ``points`` attribute
                (e.g. an ``o3d.utility.Vector3dVector``).
            colors: Value assigned to the cloud's ``colors`` attribute.
        """
        self.point_cloud.points = points
        self.point_cloud.colors = colors
        if not self.o3d_started:
            # Geometries are registered once, on the first call. The
            # coordinate frame never changes afterwards, so it is created
            # here and kept rather than rebuilt on every update (a rebuilt
            # mesh would not be the object known to the visualizer).
            self._coord_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(
                size=0.15, origin=[0, 0, 0])
            self.vis.add_geometry(self.point_cloud)
            self.vis.add_geometry(self._coord_mesh)
            self.o3d_started = True
        else:
            self.vis.update_geometry(self.point_cloud)
        self.vis.poll_events()
        self.vis.update_renderer()
if __name__ == '__main__':
    # Adjust to your data: number of part labels of the dataset
    # (13 for S3DIS, 50 for ShapeNet).
    num_classes = 50
    txt_path = './myowncloud/666.txt'

    # Read the space-separated txt cloud; the first three columns are taken
    # as the x, y, z coordinates.
    raw = np.genfromtxt(txt_path, delimiter=" ")
    xyz = raw[:, :3]

    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(xyz)
    # Normals are estimated from the geometry here. If the file already
    # stores normals or colors, they can be assigned to
    # point_cloud.normals / point_cloud.colors instead.
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=0.03, max_nn=30)
    point_cloud.estimate_normals(search_param=normal_search)
    print(np.asarray(point_cloud.points))
    print(np.asarray(point_cloud.normals))

    # One cloud -> one sample, resampled to 30000 points with normals.
    TEST_DATASET = PartNormalDataset(point_cloud, 30000, True)
    testDataLoader = torch.utils.data.DataLoader(
        TEST_DATASET,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        drop_last=True,
    )
    segmentation_model = load_models()
    predict_pcd = Generate_txt_and_3d_img(num_classes, testDataLoader, segmentation_model,
                                          visualize=True)
以上代码在我的环境中可以正常运行并得到结果。示例使用的点云是从飞机模型的点云中截取的一部分,分割效果不算特别完美,不过可视化显示已经没有问题了。