毕设需要,复现一下PointNet++的对象分类、零件分割和场景分割,找点灵感和思路,做个踩坑记录。
下载代码
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
我的运行环境是pytorch1.7+cuda11.0。
训练
PointNet++代码能实现3D对象分类、对象零件分割和语义场景分割。
对象分类
下载数据集ModelNet40,并存储在文件夹data/modelnet40_normal_resampled/
。
## e.g., pointnet2_ssg without normal features
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
python test_classification.py --log_dir pointnet2_cls_ssg
## e.g., pointnet2_ssg with normal features
python train_classification.py --model pointnet2_cls_ssg --use_normals --log_dir pointnet2_cls_ssg_normal
python test_classification.py --use_normals --log_dir pointnet2_cls_ssg_normal
## e.g., pointnet2_ssg with uniform sampling
python train_classification.py --model pointnet2_cls_ssg --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
python test_classification.py --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
- 主文件夹下运行代码
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
时可能会报错:
ImportError: cannot import name 'PointNetSetAbstraction'
原因是pointnet2_cls_ssg.py文件import时的工作目录时models文件夹,但是实际运行的工作目录时models的上级目录,因此需要在pointnet2_cls_ssg.py里把from pointnet2_utils import PointNetSetAbstraction
改成from models.pointnet2_utils import PointNetSetAbstraction
。
参考README.md文件,分类不是我的主攻点,这里就略过了。
零件分割
零件分割是将一个物体的各个零件分割出来,比如把椅子的椅子腿分出来。
下载数据集ShapeNet,并存储在文件夹data/shapenetcore_partanno_segmentation_benchmark_v0_normal/
。
运行也很简单:
## e.g., pointnet2_msg
python train_partseg.py --model pointnet2_part_seg_msg --normal --log_dir pointnet2_part_seg_msg
python test_partseg.py --normal --log_dir pointnet2_part_seg_msg
shapenet数据集txt文件格式:前三个点是xyz,点云的位置坐标,后三个点是点云的法向信息,最后一个点是这个点所属的小类别,即1表示所属50个小类别中的第一个。
写个代码用open3d可视化shapenet数据集的txt文件(随机配色):
import open3d as o3d
import numpy as np
'''
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'data_utils'))
'''
txt_path = '/home/lin/CV_AI_learning/Pointnet_Pointnet2_pytorch-master/data/shapenetcore_partanno_segmentation_benchmark_v0_normal/02691156/1b3c6b2fbcf834cf62b600da24e0965.txt'
# 通过numpy读取txt点云
pcd = np.genfromtxt(txt_path, delimiter=" ")
pcd_vector = o3d.geometry.PointCloud()
# 加载点坐标
# txt点云前三个数值一般对应x、y、z坐标,可以通过open3d.geometry.PointCloud().points加载
# 如果有法线或颜色,那么可以分别通过open3d.geomet