数据加载的两种方式:
1. 直接采用pytorch官方定义的方法: torchvision.datasets
包含一些常见的数据集
数据集无需转化为图片,直接将下载的压缩包放在root=opt.dataset_dir下即可
缺点是自己无法自定义划分数据集了
import torchvision.datasets
import torchvision.transforms as transforms
#一些转化方法
transform=transforms.Compose([transforms.Grayscale(1),
transforms.Resize(opt.picSize),
transforms.ToTensor()])
dataset=torchvision.datasets.CIFAR10(root=opt.dataset_dir,
#split=opt.data_split,
train=True, #采用训练集
transform=transform,
download=True)#是否从互联网下载
dataloader=DataLoader(dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu)
官方文档都有解释:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
2. 自己重写dataset方法进行加载
importtorch.utils.dataasdataf
fromtorch.utils.dataimportDataset
Class dataset(Dataset):
def __init__(self,input_dir):
super(dataset,self).__init__()
path_list=os.listdir(input_dir)
path_list.sort(key=lambdax:int(x.split('.')[0])) #将图片进行排序
self.input_filenames=[os.path.join(input_dir,x)forxinpath_list]
def__getitem__(self,index):
input=cv2.imread(self.input_filenames[index],0)
input=cv2.resize(input,(32,32))
input=ToTensor()(input)
#input=input.float()
returninput
def__len__(self):
returnlen(self.input_filenames)
test_img=dataset('./test/')#调用上面写的dataset,path写图片文件路径
dataloader=dataf.DataLoader(test_img,batch_size=1,shuffle=True)
读取dataloader: for i ,(img,target) in dataloader
其中img的形状:(batch_size, channel, size, size)