首先要按照大佬目录上所说 :
git clone https://github.com/fxia22/pointnet.pytorch
cd pointnet.pytorch
pip install -e .
这一步做完是下载了github上的代码
然后需要安装可视化的工具
cd script
bash build.sh #build C++ code for visualization
bash download.sh #download dataset
在这里多说一嘴,我用的是Anaconda的虚拟环境。
conda create -n pyt python=3.7 anaconda
conda activate pyt
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
作者没有说,但是看了代码还需要用到一些包:
conda install tqdm
conda install plyfile
pip install opencv-contrib-python
好了,如果以上都没有问题,那么可以进行下一步了
文档的目录如下
pointnet.pytorch
├── misc
├── pointnet
│ └── __pycache__
├── scripts
├── shapenetcore_partanno_segmentation_benchmark_v0
│ ├── 02691156
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 02773838
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 02954340
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 02958343
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03001627
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03261776
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03467517
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03624134
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03636649
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03642806
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03790512
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03797390
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 03948459
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 04099429
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 04225987
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ ├── 04379243
│ │ ├── points
│ │ ├── points_label
│ │ └── seg_img
│ └── train_test_split
└── utils
├── cls
├── __pycache__
└── seg
首先说一下数据集
数据集存放在shapenetcore_partanno_segmentation_benchmark_v0里,里面有16类样本的ShapeNet,点云文件格式都是以.pts文件的格式结尾的。
如果我们想看里面的点云怎么办呢?那就来看看!在utils里建一个名为show_points.py的文件。内容如下
'''
可视化文件夹下的点云数据
输入:n*3的矩阵
'''
from __future__ import print_function
import os,sys
sys.path.append(os.path.realpath(".."))
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
from pointnet.model import PointNetDenseCls
import matplotlib.pylab as plt
import sys
sys.path.append('/home/liqunzhao/pointnet.pytorch')
points=np.loadtxt('../shapenetcore_partanno_segmentation_benchmark_v0/02691156/points/1a04e3eab45ca15dd86060f189eb133.pts',dtype=np.float32) #预测只能输入float32的格式的数据
print(points.shape)
# 可视化
cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
# 可视化点云
showpoints(points)
#采样到2500个点
choice = np.random.choice(len(points), 2500, replace=True)
# print('choice:{}'.format(choice))
points = points[choice, :]
print('points[choice, :]:{}'.format(points))
point_np=points
# 载入模型
state_dict = torch.load('./seg/seg_model_Chair_1.pth')
classifier = PointNetDenseCls(k= state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval() #设置为评估状态
# 点云转置
points=torch.from_numpy(points)
print(points.shape)
point = points.transpose(1, 0).contiguous()
print('point.transpose(1, 0).shape: ',point.shape)
point = Variable(point.view(1, point.size()[0], point.size()[1])) #转为torch变量1,3,2500
print('--------------------')
print(point.dtype)
# point=torch.tensor(point,dtype=torch.float32)
pred, _, _ = classifier(point) #分割
print(pred)
pred_choice = pred.data.max(2)[1]
print(pred_choice.numpy()) #输出每一个点的预测类别
# print(pred_choice.size())
print(pred_choice.numpy()[0]) #[1 1 1 ... 1 1 1]
pred_color = cmap[pred_choice.numpy()[0], :] #根据分类结果显示颜色
print('\npred_color: ',pred_color.shape,'\n')
print(pred_color.dtype)
# point_np=point.numpy().reshape(2500,3)
print(point_np.shape)
print(point_np.dtype)
showpoints(point_np, pred_color,pred_color) #pred_colord的为(2500, 3)的矩阵
conda activate pyt
python show_points.py
从点云库里读取到的飞机点云。同样的,如果我们修改
points=np.loadtxt('../shapenetcore_partanno_segmentation_benchmark_v0/02691156/points/1a04e3eab45ca15dd86060f189eb133.pts',dtype=np.float32)
也可以读取到不同的点云