一、将原始数据处理成PyTorch能够加载的数据
imgs为原始的RGB图像,dpts为其对应的深度图。
train_path.txt 和 test_path.txt 为图片路径列表文件,每一行包含用空格分隔的两个路径(RGB图像路径在前,深度图路径在后),例如:/rgb0.png /depth0.png
class FireWork_Dataset(torch.utils.data.Dataset):
    """Paired RGB / depth-map dataset driven by a text index file.

    The index file is ``<data_path>/train_path.txt`` or
    ``<data_path>/test_path.txt``.  Every non-blank line holds two
    whitespace-separated paths: the RGB image first, then its depth map.
    A ``'.'`` is prepended to each path, so they resolve relative to the
    current working directory.
    """

    # Index file name for each supported split.
    _INDEX_FILES = {'train': 'train_path.txt', 'test': 'test_path.txt'}

    def __init__(self, data_path, type='train'):
        """Read the index file of the requested split.

        Args:
            data_path: Directory that contains the index files.
            type: ``'train'`` or ``'test'``.  The name shadows the builtin
                but is kept so existing keyword callers keep working.

        Raises:
            ValueError: If ``type`` is not a known split, or a non-blank
                line of the index file has fewer than two fields.
        """
        super().__init__()
        if type not in self._INDEX_FILES:
            raise ValueError(
                "type must be 'train' or 'test', got {!r}".format(type)
            )
        txt_path = data_path + '/' + self._INDEX_FILES[type]

        imgs = []
        dpts = []
        # The context manager closes the index file even if parsing fails.
        with open(txt_path, 'r') as fh:
            for line_no, line in enumerate(fh, start=1):
                words = line.split()  # split on any run of whitespace
                if not words:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                if len(words) < 2:
                    raise ValueError(
                        '{}:{}: expected "<rgb> <depth>", got {!r}'.format(
                            txt_path, line_no, line.strip()
                        )
                    )
                imgs.append('.' + words[0])
                dpts.append('.' + words[1])
        self.imgs = imgs
        self.dpts = dpts

    def __getitem__(self, index):
        """Load the RGB image and depth map at ``index`` as tensors.

        Both images go through the same ``Resize(output_size)`` +
        ``ToTensor()`` pipeline.  ``output_size`` is a name defined outside
        this class and is looked up when this method runs.

        Returns:
            A ``(img, dpt)`` tuple of transformed images.
        """
        # ``convert`` returns an independent in-memory copy, so the file
        # handles can be released as soon as the conversion is done.
        with Image.open(self.imgs[index]) as rgb_file:
            img = rgb_file.convert('RGB')
        with Image.open(self.dpts[index]) as dpt_file:
            dpt = dpt_file.convert('L')  # load as single-channel grayscale

        img_transform = transforms.Compose([
            transforms.Resize(output_size),
            transforms.ToTensor(),
        ])
        # NOTE: the original had optional depth post-processing
        # (scale / get_depth_ud) commented out here; it remains disabled.
        return img_transform(img), img_transform(dpt)

    def __len__(self):
        """Return the number of (RGB, depth) pairs listed in the index."""
        return len(self.imgs)
二、使用PIL显示加载到的图片
def tensor_to_PIL(tensor):
    """Convert a tensor into a PIL image.

    The tensor is moved to the CPU and cloned first, so the caller's tensor
    is left untouched.  Every size-1 dimension is squeezed away before the
    conversion through ``transforms.ToPILImage``.
    """
    to_pil = transforms.Compose([transforms.ToPILImage()])
    cpu_copy = tensor.cpu().clone()
    return to_pil(cpu_copy.squeeze())
def load_test():
    """Smoke-test the data loaders by displaying the first training sample.

    Takes the first batch from the training loader, shows its first RGB
    image and the matching depth map side by side, then prints the sizes of
    the two batch tensors and ``len(test_loader)``.  Only the first batch is
    inspected.
    """
    train_loader, test_loader = getFireWorkDataset()
    for imgs, dpts in train_loader:
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            dpts = dpts.cuda()

        # Convert through PIL first; tensor_to_PIL copies back to the CPU.
        img_PIL = tensor_to_PIL(imgs[0])  # RGB original
        # dpts[0] is expected to be single-channel; stack it three times
        # along the channel axis so it converts like an RGB image.
        dpt_PIL = tensor_to_PIL(torch.cat((dpts[0], dpts[0], dpts[0]), dim=0))

        # Separate axes: drawing both images on one axes would let the
        # second imshow cover the first, hiding the depth map.
        _, (rgb_ax, dpt_ax) = plt.subplots(1, 2)
        rgb_ax.imshow(img_PIL)
        dpt_ax.imshow(dpt_PIL)
        plt.show()

        print(imgs.size())
        print(dpts.size())
        print(len(test_loader))
        break