学习视频:https://www.bilibili.com/video/BV1hE411t7RN?p=1,内含环境搭建
Pytorch有两个读取数据的方式:
- 使用Dataset
- 使用DataLoader
本文先介绍第一种——Dataset
Dataset与DataLoader区别
- Dataset:提供一种方法,去获取数据及其对应的label值
- DataLoader:提供一种方法,可以以特定的形式打包数据
数据集
接下来使用的数据集下载地址:https://download.pytorch.org/tutorial/hymenoptera_data.zip
文件结构:(本人将文件夹重新命名为"dataset")
dataset
├── train
│ ├── ants
│ └── bees
└── val
├── ants
└── bees
使用torch.utils.data下Dataset读取数据
在处理数据前,首先要做的就是读取数据,torch提供了对应读取数据方法来适配其他torch的处理数据方法。 代码如下:
from torch.utils.data import Dataset # 导入Dataset后可以使用“help(Dataset)查看官方文档”
from PIL import Image # 借助PIL库导入数据图片
import os # 借助os库来用路径读入数据
class Mydata(Dataset): # 根据官方文档,自己创建的类必须继承Dataset
def __init__(self,root_dir,label_dir): # 初始化操作,传入图片所在的根目录路径(root_dir)和label的路径(label_dir)获得一个路径列表(img_path)
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir) # 用join把路径拼接一起可以避免一些因“/”引发的错误
self.img_path = os.listdir(self.path) # 将该路径下的所有文件变成一个列表
def __getitem__(self,idx) # 使用index(简写为idx)获取某个数据
img_name = self.img_path[idx] # img_path列表里每个元素就是对应图片文件名
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) # 获得对应图片路径
img = Image.open(img_item_path) # 使用PIL库下Image工具,打开对应路径图片
label = self.label_dir # 本数据集label就是文件名,如“ants”(虽然命名为dir看似路径,实则视作字符串会更容易理解)
return img,label # 返回对应图片和图片的label
# 调用
root_dir = "/content/drive/MyDrive/Pytorch学习/dataset/train"
ants_label_dir = "ants"
ants_dataset = Mydata(root_dir,label_dir)
ants_dataset[0]
结果:返回的一个元组,元组中有两个数据,一个是集合<…>部分,一个是字符串"ants"
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=441x500 at 0x7FA193C303D0>,
'ants')
因此可以这样赋值,即可显示图片
img,label = ants_dataset[0]
img.show()
小技巧:
train_dataset = ants_dataset + bees_dataset # 将两个数据集拼接起来