PyTorch中加载数据的方法是通过Dataset
和DataLoader
完成的
- Dataset提供一种方式去获取每个数据及其对应的label,告诉我们总共有多少个数据。
- Dataloader为后面的网络提供不同的数据形式,它将一批一批数据进行一个打包。
本节主要学习
- PyTorch常见的数据读取方式
- 构建自己的数据读取流程
我们可以定义自己的Dataset类来实现灵活的数据读取,定义的类需要继承PyTorch自身的Dataset类。主要包含三个函数:
__init__
: 用于向类中传入外部参数,同时定义样本集__getitem__
: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
: 用于返回数据集的样本数
import torch
from torchvision import datasets
train_data = datasets.ImageFolder(train_path)
val_data = datasets.ImageFolder(val_path)
接下来展示如何构造自己的Dataset类
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir): # 该魔术方法当创建一个事例对象时,会自动调用该函数
self.root_dir = root_dir # self.root_dir 相当于类中的全局变量
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir) # 字符串拼接,根据是Windows或Lixus系统情况进行拼接
self.img_path = os.listdir(self.path) # 获得路径下所有图片的地址
def __getitem__(self,idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "./yourpath"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
print(len(ants_dataset))
print(len(bees_dataset))
train_dataset = ants_dataset + bees_dataset # train_dataset 就是两个数据集的集合了
print(len(train_dataset))
img,label = train_dataset[200]
print("label:",label)
img.show()
构建好Dataset后,就可以使用DataLoader来按批次读入数据了
DataLoader
类的参数配置如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
读入数据代码如下:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
其中:
- batch_size:样本是按“批”读入的,batch_size就是每次读入的样本数
- num_workers:有多少个进程用于读取数据,Windows下该参数设置为0,Linux下常见的为4或者8,根据自己的电脑配置来设置
- shuffle:是否将读入的数据打乱,一般在训练集中设置为True,验证集中设置为False
- drop_last:对于样本最后一部分没有达到批次数的样本,使其不再参与训练
查看加载的数据使用next和iter来完成
import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
plt.imshow(images[0].transpose(1,2,0))
plt.show()
文章有问题欢迎交流学习
QQ:3113075063