PIL_深度估计的数据集处理

一、将原始数据处理成PyTorch能够加载的数据

        imgs为原始的RGB图像,dpts为其对应的深度图。

        _path.txt为加载图片的路径,每一行都是:/rgb0.png   /depth0.png 

class FireWork_Dataset(torch.utils.data.Dataset):
    def __init__(self, data_path, type='train'):
        txt_path = ''
        if type == 'train':
            txt_path = data_path + '/train_path.txt'
        if type == 'test':
            txt_path = data_path + '/test_path.txt'
        fh = open(txt_path, 'r')
        imgs = []
        dpts = []
        for line in fh:
            if line is not None:
                line = line.rstrip() # 去掉字符串的末尾字符
                words = line.split() # 使用空格分隔
                imgs.append('.' + words[0])
                dpts.append('.' + words[1])

        self.imgs = imgs
        self.dpts = dpts

    def __getitem__(self, index):
        img_path = self.imgs[index]
        dpt_path = self.dpts[index]

        img = Image.open(img_path).convert('RGB')
        dpt = Image.open(dpt_path).convert('L') # 加载成单通道的灰度图

        img_transform = transforms.Compose([
            transforms.Resize(output_size),
            transforms.ToTensor()
        ])

        img = img_transform(img)
        dpt = img_transform(dpt)
        # dpt = scale(dpt)
        # dpt = get_depth_ud(dpt)
        return img, dpt

    def __len__(self):
        return len(self.imgs)

二、使用PIL显示加载到图片

def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze()
    unloader = transforms.Compose([transforms.ToPILImage()])
    image = unloader(image)
    return image

def load_test():
    train_loader, test_loader = getFireWorkDataset()
    for imgs, dpts in train_loader:
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            dpts = dpts.cuda()
        # 不使用PIL库进行转换,直接使用plt画出来
        # img = imgs[0].cpu().permute(1, 2, 0)
        # plt.imshow(img)
        # plt.show()
        dpt_ud = dpts[0][0].data.cpu() #0-100
        plt.imshow(dpt_ud)
        #plt.show()
        
        # 使用PIL库先转换
        img_PIL = tensor_to_PIL(imgs[0]) # 显示rgb原图
        dpt_PIL = tensor_to_PIL(torch.cat((dpts[0], dpts[0],dpts[0]), dim=0)) # 将单通道的深度图拼接成原图
        plt.imshow(img_PIL)
        plt.show()
        print(imgs.size())
        print(dpts.size())
        print(len(test_loader))
        break

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值