三维点云学习(5)4-实现Deeplearning-PointNet-1-数据集的批量读取
GitHub PointNet 源码(pointnet.pytorch)
数据集下载:ModelNet40,包含 40 类物体的三维点云数据集
提取码:es14
因为本人初次学习 PyTorch,所以在数据处理上比较吃力。
原 GitHub 仓库中的 dataset.py 读取的 ModelNet40 数据集为 ply 格式,而课程提供的数据集为 txt 格式,因此需要修改 dataset。下面是我的调试文件 testdataset.py,可以读取 txt 格式的点云文件。
读取txt数据集代码块
注意:使用时请将代码中的数据集路径改为你本机上对应的绝对路径。
部分运行结果:
Epoch: 0 | Step: 251 | points: [[[-0.01789839 -0.6943113 0.13418637]
[-0.29483226 -0.7803244 -0.27808684]
[-0.3572845 0.17805085 0.13267519]
...
[ 0.39224106 -0.0946335 0.07569175]
[ 0.00233396 0.7481037 0.23252504]
[-0.35258114 0.43497837 -0.32852176]]
[[ 0.80344033 -0.01079553 0.06876279]
[ 0.5923547 0.00257346 -0.7231698 ]
[ 0.75024015 -0.24503842 0.20600945]
...
[ 0.58393884 0.01667958 -0.6572023 ]
[ 0.109657 -0.20571832 0.1225618 ]
[ 0.7129874 -0.25549927 -0.2599017 ]]
[[ 0.00235834 -0.45798454 -0.4980599 ]
[ 0.7268403 -0.4025967 -0.26333734]
[-0.43558982 0.64690256 -0.5365175 ]
...
[ 0.06876163 0.58287716 0.6605597 ]
[ 0.62074244 0.22889914 -0.20450239]
[-0.1456063 -0.54144853 -0.46895012]]
...
[[ 0.06745299 0.0888008 -0.10955164]
[-0.5263725 -0.21610892 -0.5766182 ]
[-0.2850928 0.13755152 -0.26902696]
...
[ 0.5999742 -0.22155094 0.5118974 ]
[-0.55020595 0.01602052 -0.798858 ]
[ 0.23198262 0.09373132 0.5464645 ]]
[[-0.18973409 -0.2597553 0.1845103 ]
[-0.2246311 0.18214802 0.08781255]
[ 0.13568395 -0.05025252 0.02492691]
...
[ 0.1176176 -0.02466865 -0.10461096]
[-0.5970161 0.67265266 0.18175025]
[ 0.6085995 0.6576732 -0.00114515]]
[[ 0.18190046 0.01715711 0.36286485]
[-0.33180034 -0.04377893 -0.93247163]
[ 0.18474244 0.07122478 0.12997185]
...
[-0.29546985 0.0260048 -0.0259549 ]
[ 0.33861616 -0.03495657 0.3082618 ]
[-0.36314163 0.06146974 -0.37071046]]] | target: [14 2 10 2 22 32 3 5 9 0 2 30 5 4 35 33 0 34 32 3 30 3 0 17
1 24 37 5 16 3 4 18]
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import argparse
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement
class ModelNetDataset(data.Dataset):
def __init__(self,
root,
npoints=2500,
split='train',
data_augmentation=True):
self.npoints = npoints
self.root = root
self.split = split
self.data_augmentation = data_augmentation
self.fns = []
with open(os.path.join(root, 'modelnet40_{}.txt'.format(self.split)), 'r') as f:
for line in f:
self.fns.append(line.strip()) #line.strip()删除末尾结束符
self.cat = {} #类别
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = int(ls[1])
print(self.cat) #构建 cat key
self.classes = list(self.cat.keys()) # class list
def __getitem__(self, index):
file_index = self.fns[index] #eg: index:845 -> bed_0114
cls_name = file_index.rsplit(sep='_',maxsplit=1)[0] # eg #airplane,rsplit:右侧分割,maxplit:限定分割符的出现次数
cls = self.cat[cls_name] #eg out:[0] airplane属于[0],bench属于[3]
file_name = '{}/{}.txt'.format(cls_name,file_index) #eg airplane/airplane_0001.txt
pts = np.loadtxt(os.path.join(self.root,file_name),delimiter=',',dtype=float)[:,:3] #读取每个point的前三个数字,分别为(x,y,z);后三位为法向量不读取
#随机降采样,提高模型的鲁棒性
choice = np.random.choice(len(pts), self.npoints, replace=True)
point_set = pts[choice, :]
point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0) # center
dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
point_set = point_set / dist # scale
if self.data_augmentation:
theta = np.random.uniform(0, np.pi * 2)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix) # random rotation
point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
point_set = torch.from_numpy(point_set.astype(np.float32))
cls = torch.from_numpy(np.array([cls]).astype(np.int64))
return point_set, cls
def __len__(self):
return len(self.fns)
if __name__ == '__main__':
    # Debug driver: build the train/test datasets from the txt-format ModelNet40
    # copy and print every training batch for a few epochs.
    # Running with no arguments reproduces the original hard-coded settings.
    parser = argparse.ArgumentParser(
        description='Batch-read the txt-format ModelNet40 dataset (debug driver).')
    parser.add_argument(
        '--dataset', type=str,
        default="D:/三维点云学习/深蓝学院课程/第五章/pointnet.pytorch-master/data/modelnet40_normal_resampled/",
        help='dataset root directory; change it to your local absolute path')
    parser.add_argument('--num_points', type=int, default=2500,
                        help='points resampled per cloud')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--workers', type=int, default=4,
                        help='number of DataLoader worker processes')
    parser.add_argument('--nepoch', type=int, default=3)
    opt = parser.parse_args()

    dataset = ModelNetDataset(
        root=opt.dataset,
        npoints=opt.num_points,
        split='train')

    test_dataset = ModelNetDataset(
        root=opt.dataset,
        npoints=opt.num_points,
        split='test',
        data_augmentation=False)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.workers)

    # The evaluation split does not need a random order.
    testdataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers)

    for epoch in range(opt.nepoch):
        # The loop variable is `batch`, not `data`: `data` is the module-level
        # alias of torch.utils.data and must not be rebound here.
        for step, batch in enumerate(dataloader):
            points, target = batch
            # Each label is a shape-(1,) tensor, so the batch is (B, 1) -> flatten to (B,).
            target = target[:, 0]
            print('Epoch: ', epoch, '| Step: ', step, '| points: ',
                  points.numpy(), '| target: ', target.numpy())