课程学习笔记,课程链接
学习笔记同步发布在我的个人网站上,欢迎来访查看。
一、 Dataset 和 Dataloader
Pytorch 读取数据主要涉及两个类:Dataset
和 Dataloader
数据可类比为“垃圾”,不同数据是不同种类的垃圾,这里蓝色是可回收垃圾。
Dataset能够把垃圾中的可回收垃圾即蓝色块给挑选出来,并对其进行编号,供后续网络的使用。
而数据进入网络不会是一个个送进去,在送进去之前会进行打包,比如以一次多个的形式把数据输入进网络。
总结:
Dataset
提供了一种方式去获取每个数据及其label并告诉我们总共有多少的数据。Dataloader
为数据进行打包,给要训练的网络提供不同形式的数据。
二、数据集初识
数据集 蚂蚁蜜蜂分类 下载链接:https://download.pytorch.org/tutorial/hymenoptera_data.zip
解压打开查看:
分为训练数据集和验证数据集。
两个文件夹都分别有分类好的蚂蚁和蜜蜂的图片:
这是一个用于对蚂蚁和蜜蜂进行二分类的数据集。
三、Dataset类初识
打开jupyter,新建一个名为 read_dataset的notebook。输入下图所示代码:
可以看到
Dataset的使用说明表示任何数据集应该继承Dataset,并改写成员函数:__getitem__
和__len__
(可选)。
这里将数据集放到工程目录下,这样就可以用相对路径进行访问了:
代码:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
# get relative address of ants pictures
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
"""
:param idx: img_name
:return: object of data,label
"""
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_label_dir ="ants"
bees_label_dir ="bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[0]
print(label)
img.show()
这里就是对数据集进行简单读取,可以通过索引来对指定的数据进行图片信息和 label 读取,输出如下图所示: