两种创建数据集方式
1.torch的类:from torch.utils.data import Dataset
创建dataset类,重写三种方法,__init__ ,__len__ , __getitem__返回图片和标签
dataset类会自动根据索引index重复调用,完成所有数据的的加载。
最后用datadloader完成数据加载
建立标签和数据的集合(用于未区分文件夹,仅对图片进行过标记)
#导入相关模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np
class AnimalData(Dataset): #继承Dataset
def __init__(self, root_dir, transform=None): #__init__是初始化该类的一些基础参数
self.root_dir = root_dir #文件目录
self.transform = transform #变换
self.images = os.listdir(self.root_dir)#目录里的所有文件
def __len__(self):#返回整个数据集的大小
return len(self.images)
def __getitem__(self,index):#根据索引index返回dataset[index]
image_index = self.images[index]#根据索引index获取该图片
img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
img = io.imread(img_path)# 读取该图片
label = img_path.split('\\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label,具体根据路径名进行分割。我这里是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,选取最后一个为['cat.0.jpg'],然后使用"."分割,选取[cat]作为该图片的标签
sample = {'image':img,'label':label}#根据图片和标签创建字典
if self.transform:
sample = self.transform(sample)#对样本进行变换
return sample #返回该样本
if __name__=='__main__':
data = AnimalData('E:/Python Project/PyTorch/dogs-vs-cats/train',transform=None)#初始化类,设置数据集所在路径以及变换
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
for i_batch,batch_data in enumerate(dataloader):
print(i_batch)#打印batch编号
print(batch_data['image'].size())#打印该batch里面图片的大小
print(batch_data['label'])#打印该batch里面图片的标签
2文件夹分类后,用imagefolder直接导入,要求一类在一个文件夹,简单快速
from torchvision import transforms,utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
import torch.utils.data
dataset = datasets.ImageFolder(r'F:\\py file\\Folder classification\\PetImages', transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
]))
print(dataset.classes)#获取标签
dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
print(len(dataset_loader))
for i_batch, img in enumerate(dataset_loader):
if i_batch == 0:
print(img[1]) #标签转化为编码
fig = plt.figure()
grid = utils.make_grid(img[0])
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
break