在学习深度学习初期,被各种算法、各种名词吸引,学习了一些相关知识、框架和方法,也跑过Mnist分类以及iris回归。但是当真正拿到一个陌生的数据集时,需要重头开始搭建一个完整的模型时,常常会感觉到无从下手,之前跑的模型都是用人家整理好的数据,一行代码就能把数据加载进行直接使用,自己完全不关心数据的加载、处理过程,只关心模型能不能训练,训练结果怎么样。这篇文章简单记录一下pytorch中自定义数据集的使用方法。
在pytorch中涉及到数据集加载的模块有两个,一个是DataSet,另一个是DataLoader。pytorch中数据加载的核心是torch.utils.data.DataLoader类,支持映射类型(map-style)和迭代类型(iterable-style)的数据集。
DataSet的描述如下:
根据描述可知,Dataset是一个抽象类,子类需要实现其中的__getitem__()方法用于获取数据集中的元素,实现__len__()方法用于获取数据集的大小。
DataLoader的构造函数定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
其中,dataset是一个加载数据集的对象,batch_size是批量大小,num_workers表示用几个子线程来并行加载数据。
下面讲一下pytorch数据加载支持的两种类型:
1、映射类型map-style
根据官网的描述,map-style类型的数据集可以简单理解为键值对类型的数据集,键可以是字典的key,也可以是数组的index,总之可以通过类似于dataset[idx]这种方式进行访问。要加载map-style类型的数据,需要重写Dataset的__getitem__()和__len__()方法。
2、迭代类型iterable-style
加载iterable-style类型的数据,需要实现IterableDataset类的__iter__()方法,适用于加载流式等不便于进行shuffle的数据。
下面通过两种典型场景讲一下使用pytorch加载map-style类型的数据集方法。
场景一:加载pandas.DataFrame或numpy数组,此种场景常出现在处理回归问题时加载原始数据集。
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
# 加载pandas.DataFrame,需要使用values将DataFrame先转换为numpy数组
# 构造numpy数组
data_X = np.random.randn(100, 5)
data_y = 3 * data_X + 5
# 自定义Dataset的子类
class MyDataset(Dataset):
# 构造器初始化方法
def __init__(self, data_X, data_y):
self.data_X = data_X
self.data_y = data_y
# 重写getitem方法用于通过idx获取数据内容
def __getitem__(self, idx):
return self.data_X[idx], self.data_y[idx]
# 重写len方法获取数据集大小
def __len__(self):
return self.data_X.shape[0]
# 构造Dataset对象
dataset = MyDataset(data_X, data_y)
# 构造DataLoader对象
dataloader = DataLoader(dataset, batch_size=16, num_workers=0,shuffle=True, drop_last=False)
for batch_X, batch_y in dataloader:
print(batch_X.shape, batch_y.shape)
输出结果如下:
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([4, 5]) torch.Size([4, 5])
场景二:加载磁盘上的图像数据,此场景常出现在处理图像分类问题时加载图像数据集。
图像数据的存放格式如下,猫的图片存放在cat目录下,狗的图片存放在dog目录下:
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision
import random
from matplotlib import pyplot as plt
class MyDataset(Dataset):
# 构造器初始化方法
def __init__(self, filenames, labels, transforms=None):
self.filenames = filenames
self.labels = labels
self.transforms = transforms
# 重写getitem方法用于通过idx获取数据内容
def __getitem__(self, idx):
# 使用Pillow Image读取图片文件
image = Image.open(self.filenames[idx]).convert("RGB")
# 对图像数据进行转换
if self.transforms is not None:
image = self.transforms(image)
return image, self.labels[idx]
# 重写len方法获取数据集大小
def __len__(self):
return len(self.filenames)
def show_image(images, labels, classes):
fig, axes = plt.subplots(1, 4, figsize=(15, 8))
for index, image in enumerate(images):
# pytorch中Tensor的shape是[C, H, W],使用matplotlib显示时,需要转换shape到[H, W, C]
image = image.numpy().transpose(1, 2, 0)
label = labels[index]
axes[index].set_title(classes[label])
axes[index].imshow(image)
# 定义图像预处理转换方法
transforms = torchvision.transforms.Compose(
[
# torchvision.transforms处理的目标是Image对象
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.RandomGrayscale(p=0.3),
# 将Image对象转换为Tensor张量
torchvision.transforms.ToTensor()
]
)
image_dataset = torchvision.datasets.ImageFolder("../data/cat_and_dog")
# image_dataset.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_dataset.samples)
# image_dataset.classes 列表中存放的类别顺序与image_dataset.samples中存放的类别索引编号相对应
classes = image_dataset.classes
# print(image_dataset.samples[:5])
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_dataset.samples:
# print(image_path, label)
filenames.append(image_path)
labels.append(label)
dataset = MyDataset(filenames, labels, transforms)
dataloader = DataLoader(dataset,batch_size=4, num_workers=0, shuffle=True, drop_last=False)
for images, labels in dataloader:
print(images.shape, labels)
# 显示读取到的图像数据,并验证类别信息是否真确
show_image(images, labels, classes)
输出结果如下:
torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 1])
torch.Size([4, 3, 224, 224]) tensor([1, 1, 0, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 0, 1, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 0])