三维点云学习(5)4-实现Deeplearning-PointNet-1-数据集的批量读取
Github PointNet源码
数据集下载：包含 40 类物体的三维点云数据集（ModelNet40）
提取码:es14
由于本人是初次学习 PyTorch，在数据处理上比较吃力。
原 GitHub 仓库中的 dataset.py 读取的 ModelNet40 数据集为 ply 格式，而课程提供的数据集为 txt 格式，因此需要对 dataset 进行修改。下面是我的调试文件 testdataset.py，它可以读取 txt 格式的点云文件。
读取txt数据集代码块
注意：使用前请将代码中的数据集路径修改为本机数据集的绝对路径。
部分运行结果:
Epoch: 0 | Step: 251 | points: [[[-0.01789839 -0.6943113 0.13418637]
[-0.29483226 -0.7803244 -0.27808684]
[-0.3572845 0.17805085 0.13267519]
...
[ 0.39224106 -0.0946335 0.07569175]
[ 0.00233396 0.7481037 0.23252504]
[-0.35258114 0.43497837 -0.32852176]]
[[ 0.80344033 -0.01079553 0.06876279]
[ 0.5923547 0.00257346 -0.7231698 ]
[ 0.75024015 -0.24503842 0.20600945]
...
[ 0.58393884 0.01667958 -0.6572023 ]
[ 0.109657 -0.20571832 0.1225618 ]
[ 0.7129874 -0.25549927 -0.2599017 ]]
[[ 0.00235834 -0.45798454 -0.4980599 ]
[ 0.7268403 -0.4025967 -0.26333734]
[-0.43558982 0.64690256 -0.5365175 ]
...
[ 0.06876163 0.58287716 0.6605597 ]
[ 0.62074244 0.22889914 -0.20450239]
[-0.1456063 -0.54144853 -0.46895012]]
...
[[ 0.06745299 0.0888008 -0.10955164]
[-0.5263725 -0.21610892 -0.5766182 ]
[-0.2850928 0.13755152 -0.26902696]
...
[ 0.5999742 -0.22155094 0.5118974 ]
[-0.55020595 0.01602052 -0.798858 ]
[ 0.23198262 0.09373132 0.5464645 ]]
[[-0.18973409 -0.2597553 0.1845103 ]
[-0.2246311 0.18214802 0.08781255]
[ 0.13568395 -0.05025252 0.02492691]
...
[ 0.1176176 -0.02466865 -0.10461096]
[-0.5970161 0.67265266 0.18175025]
[ 0.6085995 0.6576732 -0.00114515]]
[[ 0.18190046 0.01715711 0.36286485]
[-0.33180034 -0.04377893 -0.93247163]
[ 0.18474244 0.07122478 0.12997185]
...
[-0.29546985 0.0260048 -0.0259549 ]
[ 0.33861616 -0.03495657 0.3082618 ]
[-0.36314163 0.06146974 -0.37071046]]] | target: [14 2 10 2 22 32 3 5 9 0 2 30 5 4 35 33 0 34 32 3 30 3 0 17
1 24 37 5 16 3 4 18]
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import argparse
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement
class ModelNetDataset(data.Dataset):
    """ModelNet40 reader for the txt-format (comma-separated) point clouds.

    Layout read by this class:
        <root>/modelnet40_<split>.txt        one sample id per line, e.g. ``bed_0114``
        <root>/<class_name>/<sample_id>.txt  one point per line, comma-separated;
                                             only the first three columns
                                             (x, y, z) are used, the remaining
                                             columns (normals) are ignored.

    Each item is ``(point_set, cls)``: ``point_set`` is a float32 tensor of
    shape ``(npoints, 3)``, centered at the origin and scaled so its farthest
    point has norm 1; ``cls`` is an int64 tensor of shape ``(1,)`` holding the
    class id.
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True,
                 cat_file=None):
        """
        Args:
            root: dataset directory holding the split lists and class folders.
            npoints: number of points sampled (with replacement) per cloud.
            split: split name; selects ``modelnet40_<split>.txt`` under root.
            data_augmentation: if True, ``__getitem__`` applies a random
                rotation in the x-z plane plus Gaussian jitter.
            cat_file: path of the whitespace-separated "<class_name> <id>"
                mapping file. Defaults to
                ``pointnet-pytorch-master/misc/modelnet_id.txt`` next to this
                source file (the original hard-coded location).
        """
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation

        split_file = os.path.join(root, 'modelnet40_{}.txt'.format(self.split))
        with open(split_file, 'r') as f:
            # strip() drops the trailing newline; blank lines (e.g. a trailing
            # empty line) are skipped so they never reach __getitem__.
            self.fns = [line.strip() for line in f if line.strip()]

        if cat_file is None:
            cat_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    'pointnet-pytorch-master/misc/modelnet_id.txt')
        self.cat = {}  # class name -> integer class id
        with open(cat_file, 'r') as f:
            for line in f:
                ls = line.split()
                if not ls:  # ignore blank lines
                    continue
                self.cat[ls[0]] = int(ls[1])
        print(self.cat)  # debug output: the class-name -> id mapping
        self.classes = list(self.cat.keys())  # class-name list

    def __getitem__(self, index):
        """Load, resample, normalize and (optionally) augment sample ``index``.

        Raises:
            KeyError: the sample's class name is not in ``self.cat``.
            ValueError: the point-cloud file contains no points.
        """
        file_index = self.fns[index]  # e.g. 'bed_0114'
        # The class name is everything before the LAST '_', so multi-word names
        # such as 'night_stand_0001' map to 'night_stand'.
        cls_name = file_index.rsplit(sep='_', maxsplit=1)[0]
        cls = self.cat[cls_name]
        file_path = os.path.join(self.root, cls_name, file_index + '.txt')
        # ndmin=2 keeps a one-line file 2-D so the column slice still works;
        # only the first three columns (x, y, z) are kept.
        pts = np.loadtxt(file_path, delimiter=',', dtype=float, ndmin=2)[:, :3]
        if len(pts) == 0:
            raise ValueError('empty point cloud file: {}'.format(file_path))

        # Random resampling with replacement to a fixed size of npoints points
        # (also acts as a mild augmentation).
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]
        # Center at the origin, then scale so the farthest point has norm 1.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        if dist > 0:  # all points identical: skip scaling to avoid 0/0 -> NaN
            point_set = point_set / dist

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                        [np.sin(theta), np.cos(theta)]])
            # Random rotation of the x and z columns (rotation about the y axis).
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.fns)
if __name__ == '__main__':
    # Smoke test: build the train/test datasets and print a few epochs of
    # batches. The dataset root is a command-line option whose default is the
    # original absolute path; pass --dataset to run on another machine.
    parser = argparse.ArgumentParser(description='Batch-loading check for ModelNetDataset')
    parser.add_argument(
        '--dataset',
        default="/home/renzhanqi/workspace/studyMaterialsAndNotes/shenLanLidarProcess/shenLanLidarPrcess-CSDN-Materials/data/modelnet40_normal_resampled/",
        help='dataset root directory')
    parser.add_argument('--npoints', type=int, default=2500, help='points sampled per cloud')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--workers', type=int, default=4, help='DataLoader worker processes')
    parser.add_argument('--epochs', type=int, default=3, help='epochs to iterate')
    opt = parser.parse_args()

    dataset = ModelNetDataset(
        root=opt.dataset,
        npoints=opt.npoints,
        split='train')
    test_dataset = ModelNetDataset(
        root=opt.dataset,
        npoints=opt.npoints,
        split='test',
        data_augmentation=False)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.workers)
    # Evaluation does not depend on sample order, so the test loader is not shuffled.
    testdataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers)

    for epoch in range(opt.epochs):
        # The batch is bound to `batch`, not `data`: `data` is the module-level
        # alias for torch.utils.data and must not be rebound.
        for step, batch in enumerate(dataloader):
            points, target = batch
            # Each item's cls has shape (1,), so the collated target is (B, 1);
            # flatten it to (B,).
            target = target[:, 0]
            print('Epoch: ', epoch, '| Step: ', step, '| points: ',
                  points.numpy(), '| target: ', target.numpy())