用 PyTorch 读取数据确实比 TensorFlow 简单,但也需要花半个小时左右来熟悉。
下面总结一下我的体验,直接上代码。
(1)torch.utils.data.Dataset
(2)torch.utils.data.DataLoader
这两个类搭配的数据读取代码:
import os
import glob
import cv2
import numpy as np
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
#第一种数据读取方式
# Preprocessing pipeline for the images.
# NOTE(review): these torchvision transforms expect PIL Images (or tensors),
# while the Dataset below loads with cv2 (BGR ndarray) — confirm a conversion
# happens before this pipeline is actually used.
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
    # RandomSizedCrop was deprecated and later removed from torchvision;
    # RandomResizedCrop is the current name with the same semantics.
    T.RandomResizedCrop(224),
    T.ToTensor(),  # uint8 [0, 255] -> float [0, 1]
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # roughly to [-1, 1]
])
class Test_Data(Dataset):
    """Paired image/mask dataset reading ``*.jpg`` from two parallel directories.

    Parameters
    ----------
    data_root : str
        Directory containing the input images (``*.jpg``).
    mask_root : str
        Directory containing the corresponding mask images (``*.jpg``).
    transforms : callable, optional
        If given, applied to both the image and the mask in ``__getitem__``.
    """

    def __init__(self, data_root, mask_root, transforms=None):
        # Sort both listings: glob order is filesystem-dependent, so without
        # sorting, data_image[i] and mask_image[i] are not guaranteed to be
        # the matching image/mask pair.
        self.data_image = sorted(glob.glob(data_root + '/*.jpg'))
        self.mask_image = sorted(glob.glob(mask_root + '/*.jpg'))
        self.transforms = transforms

    def __getitem__(self, index):
        # -1 == cv2.IMREAD_UNCHANGED: keep the original channel count / depth.
        image_data = cv2.imread(self.data_image[index], -1)
        mask_data = cv2.imread(self.mask_image[index], -1)
        if self.transforms:
            # NOTE(review): applying a transform pipeline containing *random*
            # ops twice draws independent randomness, so image and mask
            # augmentations will not stay aligned — confirm the pipeline is
            # deterministic before relying on this.
            image_data = self.transforms(image_data)
            mask_data = self.transforms(mask_data)
        return image_data, mask_data

    def __len__(self):
        return len(self.data_image)
# NOTE(review): the `transform` pipeline defined above is never passed here
# (transforms defaults to None), so no preprocessing is applied. Before
# wiring it up, confirm the transforms accept what cv2.imread returns
# (a BGR ndarray) — torchvision transforms generally expect PIL Images.
dataset = Test_Data(data_root='../../test_image/37_simple/0001/data',mask_root='../../test_image/37_simple/0001/mask')
# First way to consume the data: iterate the Dataset directly,
# one sample at a time (no batching) — rarely used in practice.
for data,mask in dataset:
    print(data.shape,mask.shape)
下面两种方式在训练过程中更加实用:
# DataLoader adds batching, shuffling and worker processes on top of the
# Dataset. With batch_size=1, drop_last=True has no effect (every batch is
# complete); pin_memory=True allocates batches in page-locked memory to
# speed up host-to-GPU copies.
test_data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, pin_memory=True, shuffle=True,drop_last=True)
# Variant 1: enumerate the loader — each item is [image_batch, mask_batch].
# for i,data in enumerate(test_data_loader,0):
# print(data[0].shape,'..')
# print(data[1].shape,'...')
# Variant 2: unpack each batch directly.
for data_batch,mask_batch in test_data_loader:
    print(data_batch.size(),mask_batch.size())
还有一种方式是通过 from torchvision.datasets import ImageFolder 来读取文件数据。
不过,我觉得这种方式更适合在分类任务中使用。
from torchvision.datasets import ImageFolder

# ImageFolder expects root/<class_name>/<image files> and derives the
# class labels from the sub-directory names.
dataset_data = ImageFolder('../../test_image/37_simple/0001/',transform=None)

# Mapping from class (sub-directory) name to integer label.
print(dataset_data.class_to_idx)
# List of (image_path, class_index) tuples. The attribute is `imgs`,
# not `img` — the original line raised AttributeError.
print(dataset_data.imgs)
这种方式我还没有实际用过。