三维点云学习(5)4-实现Deeplearning-PointNet-1-数据集的批量读取

三维点云学习(5)4-实现Deeplearning-PointNet-1-数据集的批量读取

Github PointNet源码
数据集下载:为40种物体的三维点云数据集
提取码:es14

因为本人初次学习pytorch,所以在数据处理上比较吃力
原github上dataset.py读取的ModelNet40数据集为ply格式,课堂给予的数据集为txt格式,需要对dataset进行修改,如下为我的 testdataset.py 调试文件,可以读取txt格式的点云文件
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

读取txt数据集代码块

eg:注意使用时修改对应的数据集绝对路径

部分运行结果:

Epoch:  0 | Step:  251 | points:  [[[-0.01789839 -0.6943113   0.13418637]
  [-0.29483226 -0.7803244  -0.27808684]
  [-0.3572845   0.17805085  0.13267519]
  ...
  [ 0.39224106 -0.0946335   0.07569175]
  [ 0.00233396  0.7481037   0.23252504]
  [-0.35258114  0.43497837 -0.32852176]]

 [[ 0.80344033 -0.01079553  0.06876279]
  [ 0.5923547   0.00257346 -0.7231698 ]
  [ 0.75024015 -0.24503842  0.20600945]
  ...
  [ 0.58393884  0.01667958 -0.6572023 ]
  [ 0.109657   -0.20571832  0.1225618 ]
  [ 0.7129874  -0.25549927 -0.2599017 ]]

 [[ 0.00235834 -0.45798454 -0.4980599 ]
  [ 0.7268403  -0.4025967  -0.26333734]
  [-0.43558982  0.64690256 -0.5365175 ]
  ...
  [ 0.06876163  0.58287716  0.6605597 ]
  [ 0.62074244  0.22889914 -0.20450239]
  [-0.1456063  -0.54144853 -0.46895012]]

 ...

 [[ 0.06745299  0.0888008  -0.10955164]
  [-0.5263725  -0.21610892 -0.5766182 ]
  [-0.2850928   0.13755152 -0.26902696]
  ...
  [ 0.5999742  -0.22155094  0.5118974 ]
  [-0.55020595  0.01602052 -0.798858  ]
  [ 0.23198262  0.09373132  0.5464645 ]]

 [[-0.18973409 -0.2597553   0.1845103 ]
  [-0.2246311   0.18214802  0.08781255]
  [ 0.13568395 -0.05025252  0.02492691]
  ...
  [ 0.1176176  -0.02466865 -0.10461096]
  [-0.5970161   0.67265266  0.18175025]
  [ 0.6085995   0.6576732  -0.00114515]]

 [[ 0.18190046  0.01715711  0.36286485]
  [-0.33180034 -0.04377893 -0.93247163]
  [ 0.18474244  0.07122478  0.12997185]
  ...
  [-0.29546985  0.0260048  -0.0259549 ]
  [ 0.33861616 -0.03495657  0.3082618 ]
  [-0.36314163  0.06146974 -0.37071046]]] | target:  [14  2 10  2 22 32  3  5  9  0  2 30  5  4 35 33  0 34 32  3 30  3  0 17
  1 24 37  5 16  3  4 18]
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import argparse
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement


class ModelNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        self.fns = []
        with open(os.path.join(root, 'modelnet40_{}.txt'.format(self.split)), 'r') as f:
            for line in f:
                self.fns.append(line.strip())    #line.strip()删除末尾结束符

        self.cat = {}   #类别
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = int(ls[1])
        print(self.cat)       #构建 cat key
        self.classes = list(self.cat.keys())    # class list

    def __getitem__(self, index):
        file_index = self.fns[index]        #eg: index:845 -> bed_0114
        cls_name = file_index.rsplit(sep='_',maxsplit=1)[0]    # eg #airplane,rsplit:右侧分割,maxplit:限定分割符的出现次数
        cls = self.cat[cls_name]    #eg out:[0]    airplane属于[0],bench属于[3]
        file_name = '{}/{}.txt'.format(cls_name,file_index)  #eg airplane/airplane_0001.txt
        pts = np.loadtxt(os.path.join(self.root,file_name),delimiter=',',dtype=float)[:,:3]  #读取每个point的前三个数字,分别为(x,y,z);后三位为法向量不读取
        #随机降采样,提高模型的鲁棒性
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        return len(self.fns)


if __name__ == '__main__':
    datapath = "D:/三维点云学习/深蓝学院课程/第五章/pointnet.pytorch-master/data/modelnet40_normal_resampled/"
    dataset = ModelNetDataset(
        root=datapath,
        npoints=2500,
        split='train')
    test_dataset = ModelNetDataset(
        root=datapath,
        npoints=2500,
        split='test',
        data_augmentation=False)

    dataloder = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )
    testdataloder = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )
    for epoch in range(3):
        for step,data in enumerate(dataloder,0):     #0表示从索引0开始
            points , target = data
            target = target[:,0]
            print('Epoch: ', epoch, '| Step: ', step, '| points: ',
                  points.numpy(), '| target: ', target.numpy())
  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值