基于pytorch实现RGBD数据集的加载(以NYU数据集为例)

NYU数据集

NYU数据集由训练集和验证集两部分组成,其中包括了训练图像、深度图像和标签图像三部分组成。其中训练集由795张图像。
训练图像
训练深度图像
训练标签图像

构建Dataset

import numpy as np
import os
from torch.utils.data import Dataset
from PIL import Image


# 数据集路径
Depth_Train = ' '
Depth_Test = ' '

Image_Train = ' '
Image_Test = ' '

Label_Train = ' '
Label_Test = ' '

depth_train = os.listdir(Depth_Train)
depth_train = [os.path.join(Depth_Train, img) for img in depth_train]
depth_train.sort()

depth_test = os.listdir(Depth_Test)
depth_test = [os.path.join(Depth_Test, img) for img in depth_test]
depth_test.sort()

image_train = os.listdir(Image_Train)
image_train = [os.path.join(Image_Train, img) for img in image_train]
image_train.sort()

image_test = os.listdir(Image_Test)
image_test = [os.path.join(Image_Test, img) for img in image_test]
image_test.sort()

label_train = os.listdir(Label_Train)
label_train = [os.path.join(Label_Train, img) for img in label_train]
label_train.sort()

label_test = os.listdir(Label_Test)
label_test = [os.path.join(Label_Test, img) for img in label_test]
label_test.sort()


class RGBD_NYU(Dataset):
    
    def __init__(self, transform=None, train_phase=False):
        super(RGBD_NYU, self).__init__()
        
        self.transform = transform
        self.train_phase = train_phase

    def __len__(self):
        if self.train_phase:
            return len(image_train)
        
        else:
            return len(image_test)
        
    def __getitem__(self, idx):
        if self.train_phase:
            img_dir = image_train
            depth_dir = depth_train
            label_dir = label_train

        else:
            img_dir = image_test
            depth_dir = depth_test
            label_dir = label_test

        # 打开图像
        image = Image.open(img_dir[idx])
        depth = Image.open(depth_dir[idx])
        label = Image.open(label_dir[idx])

        # 格式转为array
        image = np.asarray(image)
        depth = np.asarray(depth)
        label = np.asarray(label)

        sample = {'image': image, 'depth': depth, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

完成了RGBD数据集的构建,接下来对输入图像进行一系列增强方法。

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

卡子爹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值