一 、Dataset 的使用
Dataset 告诉我们如何获取每一个数据及其lable,同时使用Dataset可以查看有多少个数据
使用Dataset从数据集中读取数据和展示数据
from torch.utils.data import Dataset
import cv2
# python中关于系统的库
import os
class MyData(Dataset):
# 初始化 self指定了类中的全局变量
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
# 地址拼接
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
# 读取图片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset+bees_dataset
二、Dataloader 的使用
DataLoader是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。通俗理解就是主要是把获取到的数据(Dataset)进行打包,为后面的网络提供不同的数据形式。