Dataset类
Dataset
类是数据加载的核心组件之一。它是一个抽象类,用户需要通过继承这个类并实现其中的两个方法:__len__
和__getitem__
。
1. 数据集结构
-
数据分为训练集和测试集,训练集和测试集中分别有两个文件夹,文件夹名称为数据的类别,每个类别文件夹下有多个数据
2. 搭建框架
-
需要继承
torch.utils.data
中的Dataset
类,并重写两个魔法方法__getitem__
__len__
3.__init__
-
在初始化函数中完成对图像数据名称的获取,用于对后期数据的加载
4.__getitem__
-
getitem要根据给定的索引返回一个样本。通常会包含数据、标签,必要时还会应用数据变换
5.__len__
-
len方法用于返回加载的数据集中有多少个数据
完整代码
from torch.utils.data import Dataset
import os
from PIL import Image
class Mydata(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dit = root_dir # 根目录 ./data/traim
self.label_dir = label_dir # 类别 ants
self.image_path = os.path.join(self.root_dit, self.label_dir) # './data/train/ants
self.image_path_list = os.listdir(self.image_path) # 获取ants下的所有文件的名称
def __getitem__(self, index):
image = self.image_path_list[index] # 通过编号获取图像的名称
image_item_path = os.path.join(self.image_path, image) # 拼接出图像的具体路径
img = Image.open(image_item_path)
label = self.label_dir
return img, label # 返回图像数据、标签
def __len__(self):
return len(self.image_path_list)
两个Dataset实例求和的数据集
-
Dataset类的实例化支持求和操作,首先需要设置len方法,两个Dataset的实例的求和是将
__len__
方法中的返回的计算长度的列表作为数据集相加,从而得到新的数据集