自定义Dataset例子(torch)

前言及简介:

下面以NYU Depth V2 Dataset为例,进行自定义Dataset数据集官网,

下方实现代码的数据集直接来源为Fangchang Ma此人在原始NYU Depth V2 Dataset基础上做过预处理得到的数据集,样本形式为.h5,
github:https://github.com/fangchangma/sparse-to-dense.pytorch

NYU Depth V2 Dataset此数据集可用于深度估计,下面是一些基本概念。
深度是指将从图像采集器到场景中各点的距离(深度)。
深度图像就是图像像素值表示深度的图像。
单目深度估计就是使用单张RGB图片来对深度进行预测。
下面的colored_depth是depth经过jet映射的结果:
灰度值越低,就表示距离越近,色温就越高(越冷)。
灰度值越高,就表示距离越远,色温就越低(越暖)。
在这里插入图片描述

构建Dataset代码示例

1.数据集文件目录结构

train文件夹之下是场景名,场景名之下是保存为h5形式的样本。一个样本包括一张高宽为(480,640)的RGB图和一张(480,640)的深度图(见上图)。另外,这里的val其实是测试集。
数据集的文件太多,不好用tree打印,部分目录结构如下:

2. 构建自定义Dataset代码示例

可以就着这篇文章一起小火慢炖地看看。

import os
import h5py
from torch.utils.data import DataLoader,Dataset
import numpy as np

#这里是自定义的transforms,你用的话就自行换成torchvision的transforms,然后DIY组合transforms.Compose里面的函数
from dataloaders import transforms
#from torchvision import  transforms


def h5_loader(path):
    # 在数据集中,每个H5保存一张图片和深度图
    # raw shape: rgb   c,h,w(3, 480, 640) depth (480, 640)  => rgb (480, 640, 3) depth (480, 640)
    h5f = h5py.File(path, "r")
    rgb = np.array(h5f['rgb'])
    rgb = np.transpose(rgb, (1, 2, 0))
    depth =np.array(h5f['depth'])
    return rgb, depth

class MyDataset(Dataset):

    # 根据split(train,val,test)进行筛选
    def is_h5file(self,filename):
        if self.split == 'train':  # 训练集
            return (filename.endswith('.h5') and \
                    '00001.h5' not in filename and '00201.h5' not in filename)
        elif self.split == 'val':  # 验证集
            return ('00001.h5' in filename or '00201.h5' in filename)
        elif self.split == 'test': # 测试集
            return (filename.endswith('.h5'))
        else:
            raise (RuntimeError("Invalid dataset split: " + self.split + "\n"
                                                                         "Supported dataset splits are: train, val,test"))
    #构建文件路径索引列表
    def getdata(self,path):
        img_paths=[]
        dirs = os.listdir(path)  #训练集返回的是一个包含各个场景名的列表(见下图1),验证集则是official
        for dir in dirs:
            son_path=os.path.join(path,dir)  #训练集:场景文件夹的路径 测试集:到official的路径
            for root,dir_name_list,file_list in os.walk(son_path):
                # os.walk 返回遍历文件夹的路径、子文件夹名、文件名列表
                # 比如在训练集,这里son_path其实已经是场景名文件夹的路径的形式了,比如~/train/basement_001a
                # 所以root是son_path,其中都是h5文件,所以文件夹名应该为空,文件名应该是h5文件名的List,结果见下图3
                # print("root",root)
                # print("dir_name_list",dir_name_list)
                # print("file_list",file_list)
                for file in file_list:
                    if self.is_h5file(file): #判断文件拓展名是不是h5
                        file_path=os.path.join(root,file) #h5文件的路径
                        img_paths.append(file_path)   #塞到列表里
        return img_paths        #返回一个文件路径的列表

    def train_transform(self, rgb, depth): #训练集预处理
        depth_np = depth
        do_flip = np.random.uniform(0.0, 1.0) > 0.5  # random horizontal flip

        transform = transforms.Compose([
            transforms.HorizontalFlip(do_flip),  # 0.5概率水平翻转
            transforms.CenterCrop((228, 304))  #中心裁剪,裁成(228,304)的大小

        ])
        rgb_np = transform(rgb)

        rgb_np = self.color_jitter(rgb_np)  # random color jittering
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255  #对RGB归一化
        depth_np = transform(depth_np)   #depth图像跟rgb进行相同的几何变换,不需要色彩变换
        return rgb_np,depth_np

    def val_transform(self, rgb, depth):  #验证集、测试集预处理
        depth_np = depth
        transform = transforms.Compose([
            transforms.CenterCrop((228, 304)),
        ])
        rgb_np = transform(rgb)
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255
        depth_np = transform(depth_np)

        return rgb_np, depth_np

    def __init__(self,path,split,loader):
        self.split=split
        self.imgs_path=self.getdata(path)
        self.color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)  # (亮度,对比度,饱和度)在1-0.4,1+0.4间抖动
        if split == 'train':
            self.transform = self.train_transform
        elif split == 'val':
            self.transform = self.val_transform
        elif split == 'test':
            self.transform = self.val_transform
        else:
            raise (RuntimeError("Invalid dataset split: " + split + "\n"
                                                                    "Supported dataset splits are: train, val,test"))
        self.loader = loader
    def __getitem__(self, index):
        rgb,depth=self.loader(self.imgs_path[index])  #从路径列表中加载样本
        input_np, depth_np = self.transform(rgb, depth)  # input_np (228, 304, 3) depth_np (228 304)

        #将预处理过的数据从numpy=>tensor
        to_tensor = transforms.ToTensor()  #(H,W,C)=>(C,H,W),并转成Tensor
        input_tensor = to_tensor(input_np) # ([3, 228, 304])
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)  #  [228,304]=> ([1,228,304])


        return input_tensor, depth_tensor
    def __len__(self):
        return len(self.imgs_path)




## 测试代码
train_path=os.path.join("Data","nyudepthv2","train")
val_path=train_path
test_path=os.path.join("Data","nyudepthv2","val")

train_dataset=MyDataset(train_path,'train',h5_loader)
val_dataset=MyDataset(val_path,"val",h5_loader)
test_dataset=MyDataset(test_path,'test',h5_loader)


train_loader = DataLoader(train_dataset,
        batch_size=1, shuffle=True, num_workers=0, pin_memory=False)
val_loader = DataLoader(val_dataset,
        batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
test_loader = DataLoader(test_dataset,
        batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))
for i, trainsample in enumerate(train_loader):
    print("i:",i)
    Data, Label = trainsample
    print("data:", Data.shape)
    print("label:", Label.shape)
    if(i==3):
        break
        

3. 部分结果展示

图1 :数据集train文件夹的场景文件夹名在这里插入图片描述
图2:经过水平翻转和中心裁减、颜色扰动的训练集RGB图

图3:os.walk结果
在这里插入图片描述

  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值