NYU数据集
NYU数据集由训练集和验证集（代码中记为Test）两部分组成，每部分都包含RGB图像、深度图像和标签图像三类数据。其中训练集包含795张图像。
构建Dataset
import numpy as np
import os
from torch.utils.data import Dataset
from PIL import Image
# Dataset directories. Replace the ' ' placeholders with real directory paths
# before importing this module: the listings below run at import time, and
# os.listdir(' ') fails unless a directory literally named ' ' exists.
Depth_Train = ' '
Depth_Test = ' '
Image_Train = ' '
Image_Test = ' '
Label_Train = ' '
Label_Test = ' '


def _sorted_paths(directory):
    """Return the sorted full paths of the regular, non-hidden files in *directory*.

    Image, depth and label samples are paired purely by their position in
    these sorted lists. A stray entry (e.g. ``.DS_Store`` or a sub-directory)
    in only one directory would shift every later pairing, so such entries
    are skipped.
    """
    paths = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not name.startswith('.')
    ]
    return sorted(p for p in paths if os.path.isfile(p))


depth_train = _sorted_paths(Depth_Train)
depth_test = _sorted_paths(Depth_Test)
image_train = _sorted_paths(Image_Train)
image_test = _sorted_paths(Image_Test)
label_train = _sorted_paths(Label_Train)
label_test = _sorted_paths(Label_Test)
class RGBD_NYU(Dataset):
    """Paired RGB / depth / label samples read from the NYU dataset directories.

    Samples come from the module-level sorted path lists (``image_*``,
    ``depth_*``, ``label_*``) and are paired by index, so the three lists of
    a split must have the same length.

    Args:
        transform: optional callable applied to the sample dict produced by
            ``__getitem__``; skipped when ``None``.
        train_phase: if True use the ``*_train`` lists, otherwise ``*_test``.

    Raises:
        ValueError: if the image, depth and label lists of the selected split
            differ in length (index-based pairing would be misaligned).
    """

    def __init__(self, transform=None, train_phase=False):
        super().__init__()
        self.transform = transform
        self.train_phase = train_phase
        images, depths, labels = self._split_paths()
        # Fail early instead of silently pairing the wrong files (or hitting an
        # IndexError deep inside a DataLoader worker).
        if not len(images) == len(depths) == len(labels):
            raise ValueError(
                f"mismatched NYU split sizes: {len(images)} images, "
                f"{len(depths)} depth maps, {len(labels)} labels"
            )

    def _split_paths(self):
        """Return the (image, depth, label) path lists for the active split."""
        if self.train_phase:
            return image_train, depth_train, label_train
        return image_test, depth_test, label_test

    @staticmethod
    def _load_array(path):
        """Open *path* with PIL and return its pixels as a writable NumPy array."""
        # The context manager closes the file handle deterministically.
        # np.array copies the pixels, so the result stays valid after the file
        # is closed and is writable (np.asarray on a PIL image can be read-only).
        with Image.open(path) as img:
            return np.array(img)

    def __len__(self):
        return len(self._split_paths()[0])

    def __getitem__(self, idx):
        images, depths, labels = self._split_paths()
        sample = {
            'image': self._load_array(images[idx]),
            'depth': self._load_array(depths[idx]),
            'label': self._load_array(labels[idx]),
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
至此完成了RGBD数据集的构建，接下来对输入图像应用一系列数据增强方法。