在Pytorch中定义数据集主要涉及到两个主要的类:Dataset、DataLoader。
Dataset类
Dataset
类是Pytorch
中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数__len__、__getitem__必须被重载,否则将会触发错误提示:
其中__len__
应该返回数据集的大小,而__getitem__
实现可以通过索引来返回图像数据的功能。
我们要定义自己的数据集类,首先继承上面的Dataset
类,然后在__init__()
方法中对数据集进行整理,得到图像的路径,给图片打标签,划分数据集等。
另外,如果我们需要在读取数据的同时对图像进行增强的话,可以在__getitem__(self, index)
函数中设置图像增强的代码,图像增强的方法可以使用Pytorch
内置的图像增强方式,也可以使用自定义或者其他的图像增强库。这个很灵活,当然要记住一点,在Pytorch
中得到的图像必须是tensor
,也就是说我们还需要在__getitem__中将读取到的数据转换为tensor。
DataLoader类
Dataset
类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:
- 可以分批次读取:batch-size
- 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
- 可以并行加载数据(利用多核处理器加快载入数据的效率)
这时候就需要Dataloader
类了,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)。Dataloader
这个类并不需要我们自己设计代码,我们只需要利用DataLoader
类读取我们设计好的Dataset子类
即可:
# 利用dataloader读取我们的数据对象,并设定batch-size和工作进程
loader = DataLoader(train_dataset, batch_size=16, num_workers=4, shuffle=True)
这时候通过loader
返回的数据就是按照batch_size
来返回特定数量的训练数据的tensor;
利用了多进程,读取数据的速度相比单进程快很多;设置了数据的随机读取,打乱了数据集分布的顺序。
参考:https://www.cnblogs.com/ranjiewen/p/10128046.html
实例
下面通过网络上收集的神奇宝贝图片,制作图像分类数据集。
数据集链接:https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw
提取码:dsxl
上面的数据集中有1168张宝可梦的图片,其中皮卡丘234张、超梦239张、杰尼龟223张、小火龙238、张妙蛙种子234张。
下载后的目录结构如下:
每个目录由神奇宝贝名字命名,对应目录有下是该神奇宝贝的图片,图片的格式有jpg、png、jpeg三种。
数据集的划分如下:
训练集60%,验证集20%,测试集20%。
代码实现:
#coding=utf-8
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
class Pokemon(Dataset):
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.name2label = {}
# 返回指定目录下的文件列表,并对文件列表进行排序,
# os.listdir每次返回目录下的文件列表顺序会不一致,
# 排序是为了每次返回文件列表顺序一致
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉非目录文件
if not os.path.isdir(os.path.join(root, name)):
continue
#构建字典,名字:0~4数字
self.name2label[name] = len(self.name2label.keys())
# eg: {'squirtle': 4, 'bulbasaur': 0, 'pikachu': 3, 'mewtwo': 2, 'charmander': 1}
print(self.name2label)
# image, label
self.images, self.labels = self.load_csv("images.csv")
# 对数据集进行划分
if mode == "train": # 60%
self.images = self.images[:int(0.6*len(self.images))]
self.labels = self.labels[:int(0.6*len(self.labels))]
elif mode == "val": # 20% = 60%~80%
self.images = self.images[int(0.6*len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6*len(self.labels)):int(0.8 * len(self.labels))]
else: # 20% = 80%~100%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]
# 将目录下的图片路径与其对应的标签写入csv文件,
# 并将csv文件写入的内容读出,返回图片名与其标签
def load_csv(self, filename):
"""
:param filename:
:return:
"""
# 是否已经存在了cvs文件
if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.name2label.keys():
# 获取指定目录下所有的满足后缀的图像名
# pokemon/mewtwo/00001.png
images += glob.glob(os.path.join(self.root, name, "*.png"))
images += glob.glob(os.path.join(self.root, name, "*.jpg"))
images += glob.glob(os.path.join(self.root, name, "*.jpeg"))
# 1165 'pokemon/pikachu/00000058.png'
print(len(images), images)
# 将元素打乱
random.shuffle(images)
with open(os.path.join(self.root, filename), mode="w", newline="") as f:
writer = csv.writer(f)
for img in images: # 'pokemon/pikachu/00000058.png'
name = img.split(os.sep)[-2]
label = self.name2label[name]
# 将图片路径以及对应的标签写入到csv文件中
# 'pokemon/pikachu/00000058.png', 0
writer.writerow([img, label])
print("writen into csv file: ", filename)
# 如果已经存在了csv文件,则读取csv文件
images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
# 'pokemon/pikachu/00000058.png', 0
img, label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels)
return images, labels
def __len__(self):
return len(self.images)
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hat = (x-mean)/std
# x = x_hat*std = mean
# x: [c, h, w]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
def __getitem__(self, idx):
# idx~[0~len(images)]
# self.images, self.labels
# img: 'pokemon/bulbasaur/00000000.png'
# label: 0
img, label = self.images[idx], self.labels[idx]
tf = transforms.Compose([
lambda x:Image.open(x).convert("RGB"), # string path => image data
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img = tf(img)
label = torch.tensor(label)
return img, label
def main():
import visdom
import time
viz = visdom.Visdom()
db = Pokemon("pokemon", 224, "train")
x, y = next(iter(db))
print("sample: ", x.shape, y.shape, y)
viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=4)
for x, y in loader:
viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
time.sleep(10)
if __name__ == '__main__':
main()
代码中定义的子类Pokemon继承自Dataset类,重写了父类的__len__、__getitem__方法。
在__init__方法中首先读取图片路径,构建了标签字典name2label,接下来将图片路径以及对应的标签写入到csv文件中,通过保存在csv文件中的路径与标签信息,划分数据集,csv文件内容如下:
在__len__方法中返回了数据集的大小。
在__getitem__方法中进行图像缩放、图像旋转、图像去中心、转换到tensor、归一化等数据增强。
另外,在Pokemon类中实现了denormalize方法,也就是反归一化,因为在对图像进行归一化处理后,在visdom显示图像的时候,可见度不高,因此denormalize方法仅在visdom显示图像的时候调用。
在main()函数中,进行数据集的可视化,可以看到继承自Dataset类的Pokemon类可以通过迭代器iter进行访问,并通过visdom进行可视化展示,另外,通过DataLoader类实现了对数据集的加载,在visdom中以32个batch进行加载。
验证效果
启动两个终端,分别执行如下两条命令:
python -m visdom.server
python pokemon.py
复制第一个终端中visdom链接 http://localhost:8097到浏览器
可以看到,32个batch的图片与标签在浏览器中展示。