PyTorch学习笔记1:Dataset和Dataloader

PyTorch学习笔记1:Dataset和Dataloader

1. 加载数据集

1.1 Dataset

提供一种方式获取数据及其标签(label)

  • 如何获取每一个数据及其label
  • 告诉我们总共有多少的数据

代码示例:
这里使用的数据集为蚂蚁蜜蜂数据集:

数据集下载地址
密码:5suq

定义Dataset类

from torch.utils.data import Dataset # PyTorch的数据集模块
from PIL import Image # 图像加载模块 
import os # 路径操作模块


# 继承Dataset类,重写__getitem__(self, index), __len__(self)方法
class MyData(Dataset):
  # 初始化函数
  def __init__(self, root_dir, label_dir):
    self.root_dir = root_dir # 包含图像的根目录
    self.label_dir = label_dir # 包含类别标签的子目录 
    self.path = os.path.join(self.root_dir, self.label_dir) # 标签的完整路径
    self.img_path = os.listdir(self.path) # 图像文件名列表

  # 获取数据集中单个样本的函数
  def __getitem__(self, index):
    img_name = self.img_path[index] # 获取图像文件名
    img_item_path = os.path.join(self.path, img_name) # 获取图像完整路径 
    img = Image.open(img_item_path) # 加载图像
    label = self.label_dir # 标签是子目录名称
    return img, label # 返回图像和标签

  # 获数据集大小函数
  def __len__(self):
    return len(self.img_path) # 图像数量

实例化类并使用:

# 实例化并使用
ants_dataset = MyData('data/hymenoptera_data/train', 'ants')
img, label = ants_dataset[1]
print('数据集大小为:', len(ants_dataset))
print("标签为:", label)
img.show()

运行结果

1.2 Dataloader

为后面的网络提供不同的数据形式(迭代器)

采样并以迭代的形式提供数据
分batch、打乱之类的操作:
在这里插入图片描述

Dataloader常用参数介绍

参数数据类型解释
datasetDataset加载数据的数据集
batchsizeint每批次加载多少个样本(默认值:1)
shufflebool是否打乱
num_workersint是否多进程读取,默认是0表示主进程,-1表示所有
drop_lastbool当样本数不能被batchsize整除时, 是否舍弃最后一批数据1

  1. 为了理解drop_last,我们需要搞清楚epoch,iteration和batch_size:
    epoch:所有训练样本都经过了模型一次训练,称为一个epoch
    Iteration:一批样本输入到模型中,称为一个Iteration
    batch_size: 一批样本(bath)的大小, 决定一个Epoch有多少个Iteration ↩︎

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值