代码:
import numpy as np import warnings import os from torch.utils.data import Dataset warnings.filterwarnings('ignore') def pc_normalize(pc): centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc**2, axis=1))) pc = pc / m return pc def farthest_point_sample(point, npoint): """ Input: xyz: pointcloud data, [N, D] npoint: number of samples Return: centroids: sampled pointcloud index, [npoint, D] """ N, D = point.shape xyz = point[:,:3] centroids = np.zeros((npoint,)) distance = np.ones((N,)) * 1e10 farthest = np.random.randint(0, N) for i in range(npoint): centroids[i] = farthest centroid = xyz[farthest, :] dist = np.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = np.argmax(distance, -1) point = point[centroids.astype(np.int32)] return point class ModelNetDataLoader(Dataset): def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000): self.root = root self.npoints = npoint self.uniform = uniform self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') self.cat = [line.rstrip() for line in open(self.catfile)] self.classes = dict(zip(self.cat, range(len(self.cat)))) self.normal_channel = normal_channel shape_ids = {} shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] assert (split == 'train' or split == 'test') shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] # list of (shape_name, shape_txt_file_path) tuple self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i in range(len(shape_ids[split]))] print('The size of %s data is %d'%(split,len(self.datapath))) self.cache_size = cache_size # how many data points to cache in memory self.cache = {} # from index to (point_set, cls) tuple def __len__(self): return len(self.datapath) def _get_item(self, index): if index in self.cache: point_set, cls = self.cache[index] else: fn = self.datapath[index] cls = self.classes[self.datapath[index][0]] cls = np.array([cls]).astype(np.int32) point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) if self.uniform: point_set = farthest_point_sample(point_set, self.npoints) else: point_set = point_set[0:self.npoints,:] point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])#相当于3列值,做标准化 if not self.normal_channel: point_set = point_set[:, 0:3] if len(self.cache) < self.cache_size: self.cache[index] = (point_set, cls) return point_set, cls def __getitem__(self, index): return self._get_item(index) if __name__ == '__main__': import torch data = ModelNetDataLoader('/data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,) DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) for point,label in DataLoader: print(point.shape) print(label.shape)
代码解释:
这段代码是用于加载和预处理ModelNet数据集的脚本。
首先导入了所需的库和模块,包括numpy、warnings、os以及torch.utils.data中的Dataset。
然后设置警告忽略,以防止警告信息干扰。
之后定义了一些辅助函数和类。
- `pc_normalize(pc)`:该函数用于对点云数据进行归一化处理,将坐标值减去平均值并除以最大值,以使其范围在[-1,1]之间。
- `farthest_point_sample(point, npoint)`:该函数用于进行最远点采样,从给定的点云中采样出npoint个点作为代表。逐步选择离中心点距离最远的点,并将其添加到结果中。
- `ModelNetDataLoader`:这是一个自定义的数据加载类,继承自`torch.utils.data.Dataset`。它用于加载ModelNet数据集,并根据需要进行采样和预处理。
在这个类中,首先初始化一些参数,如数据集路径、采样点数、是否均匀采样等。
然后根据数据集的类别文件和划分文件,构建数据集的索引。
通过实现`__len__()`和`__getitem__()`方法,可以获取数据集的长度和具体的样本。
在`_get_item()`函数中,根据索引获取点云数据,并进行一系列的处理,如根据是否均匀采样进行采样、归一化等。
最后,在`__main__`函数中,创建了一个`ModelNetDataLoader`对象,并通过`torch.utils.data.DataLoader`创建数据加载器。
遍历数据加载器中的数据,打印点云数据和标签的形状。
总的来说,这段代码定义了一些用于数据加载和预处理的辅助函数和类。它允许从ModelNet数据集中加载并处理点云数据。