前言及简介:
下面以NYU Depth V2 Dataset为例,进行自定义Dataset。数据集官网,
下方实现代码的数据集直接来源为Fangchang Ma此人在原始NYU Depth V2 Dataset基础上做过预处理得到的数据集,样本形式为.h5,
github:https://github.com/fangchangma/sparse-to-dense.pytorch。
NYU Depth V2 Dataset此数据集可用于深度估计,下面是一些基本概念。
深度是指将从图像采集器到场景中各点的距离(深度)。
深度图像就是图像像素值表示深度的图像。
单目深度估计就是使用单张RGB图片来对深度进行预测。
下面的colored_depth是depth经过jet映射的结果:
灰度值越低,就表示距离越近,色温就越高(越冷)。
灰度值越高,就表示距离越远,色温就越低(越暖)。
构建Dataset代码示例
1.数据集文件目录结构
train文件夹之下是场景名,场景名之下是保存为h5形式的样本。一个样本包括一张高宽为(480,640)的RGB图和一张(480,640)的深度图(见上图)。另外,这里的val其实是测试集。
数据集的文件太多,不好用tree打印,部分目录结构如下:
2. 构建自定义Dataset代码示例
可以就着这篇文章一起小火慢炖地看看。
import os
import h5py
from torch.utils.data import DataLoader,Dataset
import numpy as np
#这里是自定义的transforms,你用的话就自行换成torchvision的transforms,然后DIY组合transforms.Compose里面的函数
from dataloaders import transforms
#from torchvision import transforms
def h5_loader(path):
# 在数据集中,每个H5保存一张图片和深度图
# raw shape: rgb c,h,w(3, 480, 640) depth (480, 640) => rgb (480, 640, 3) depth (480, 640)
h5f = h5py.File(path, "r")
rgb = np.array(h5f['rgb'])
rgb = np.transpose(rgb, (1, 2, 0))
depth =np.array(h5f['depth'])
return rgb, depth
class MyDataset(Dataset):
# 根据split(train,val,test)进行筛选
def is_h5file(self,filename):
if self.split == 'train': # 训练集
return (filename.endswith('.h5') and \
'00001.h5' not in filename and '00201.h5' not in filename)
elif self.split == 'val': # 验证集
return ('00001.h5' in filename or '00201.h5' in filename)
elif self.split == 'test': # 测试集
return (filename.endswith('.h5'))
else:
raise (RuntimeError("Invalid dataset split: " + self.split + "\n"
"Supported dataset splits are: train, val,test"))
#构建文件路径索引列表
def getdata(self,path):
img_paths=[]
dirs = os.listdir(path) #训练集返回的是一个包含各个场景名的列表(见下图1),验证集则是official
for dir in dirs:
son_path=os.path.join(path,dir) #训练集:场景文件夹的路径 测试集:到official的路径
for root,dir_name_list,file_list in os.walk(son_path):
# os.walk 返回遍历文件夹的路径、子文件夹名、文件名列表
# 比如在训练集,这里son_path其实已经是场景名文件夹的路径的形式了,比如~/train/basement_001a
# 所以root是son_path,其中都是h5文件,所以文件夹名应该为空,文件名应该是h5文件名的List,结果见下图3
# print("root",root)
# print("dir_name_list",dir_name_list)
# print("file_list",file_list)
for file in file_list:
if self.is_h5file(file): #判断文件拓展名是不是h5
file_path=os.path.join(root,file) #h5文件的路径
img_paths.append(file_path) #塞到列表里
return img_paths #返回一个文件路径的列表
def train_transform(self, rgb, depth): #训练集预处理
depth_np = depth
do_flip = np.random.uniform(0.0, 1.0) > 0.5 # random horizontal flip
transform = transforms.Compose([
transforms.HorizontalFlip(do_flip), # 0.5概率水平翻转
transforms.CenterCrop((228, 304)) #中心裁剪,裁成(228,304)的大小
])
rgb_np = transform(rgb)
rgb_np = self.color_jitter(rgb_np) # random color jittering
rgb_np = np.asfarray(rgb_np, dtype='float') / 255 #对RGB归一化
depth_np = transform(depth_np) #depth图像跟rgb进行相同的几何变换,不需要色彩变换
return rgb_np,depth_np
def val_transform(self, rgb, depth): #验证集、测试集预处理
depth_np = depth
transform = transforms.Compose([
transforms.CenterCrop((228, 304)),
])
rgb_np = transform(rgb)
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
depth_np = transform(depth_np)
return rgb_np, depth_np
def __init__(self,path,split,loader):
self.split=split
self.imgs_path=self.getdata(path)
self.color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4) # (亮度,对比度,饱和度)在1-0.4,1+0.4间抖动
if split == 'train':
self.transform = self.train_transform
elif split == 'val':
self.transform = self.val_transform
elif split == 'test':
self.transform = self.val_transform
else:
raise (RuntimeError("Invalid dataset split: " + split + "\n"
"Supported dataset splits are: train, val,test"))
self.loader = loader
def __getitem__(self, index):
rgb,depth=self.loader(self.imgs_path[index]) #从路径列表中加载样本
input_np, depth_np = self.transform(rgb, depth) # input_np (228, 304, 3) depth_np (228 304)
#将预处理过的数据从numpy=>tensor
to_tensor = transforms.ToTensor() #(H,W,C)=>(C,H,W),并转成Tensor
input_tensor = to_tensor(input_np) # ([3, 228, 304])
depth_tensor = to_tensor(depth_np)
depth_tensor = depth_tensor.unsqueeze(0) # [228,304]=> ([1,228,304])
return input_tensor, depth_tensor
def __len__(self):
return len(self.imgs_path)
## 测试代码
train_path=os.path.join("Data","nyudepthv2","train")
val_path=train_path
test_path=os.path.join("Data","nyudepthv2","val")
train_dataset=MyDataset(train_path,'train',h5_loader)
val_dataset=MyDataset(val_path,"val",h5_loader)
test_dataset=MyDataset(test_path,'test',h5_loader)
train_loader = DataLoader(train_dataset,
batch_size=1, shuffle=True, num_workers=0, pin_memory=False)
val_loader = DataLoader(val_dataset,
batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
test_loader = DataLoader(test_dataset,
batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))
for i, trainsample in enumerate(train_loader):
print("i:",i)
Data, Label = trainsample
print("data:", Data.shape)
print("label:", Label.shape)
if(i==3):
break
3. 部分结果展示
图1 :数据集train文件夹的场景文件夹名
图2:经过水平翻转和中心裁减、颜色扰动的训练集RGB图
图3:os.walk结果