Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】

数据载入由什么组成?

数据载入由dataset和dataloader组成。
dataset:提供一种方式去获取数据及其label
dataloader: 为后面网络提供不同的数据形式

1. Dataset的功能

dataset主要为了实现两个功能
1.如何获取每个数据及其label
2.告诉我们总共有多少的数据

2. Dataset代码:

1) 查看官方文档解释

首先,在anaconda prompt中输入如下代码,打开jupyter环境

conda activate <pytorch环境名称> #激活pytorch环境 
jupyter notebook #打开jupyter

然后,在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import Dataset
help(Dataset)

2) pycharm中进行程序编写

(1)下载数据集
蚁蜜蜂分类数据集
https://download.pytorch.org/tutorial/hymenoptera_data.zip
(2)
建立dataset文件,并将数据集放入程序下
在这里插入图片描述
(3)编写数据集载入程序,实现dataset两个功能,第一,如何获取每个数据及其label
第二,告诉我们总共有多少的数据。



from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms

class mydata(Dataset):

    #设置全局参数
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        # 获得图片的路径地址
        self.path=os.path.join(self.root_dir,self.label_dir) #os.path.join()函数用于路径拼接文件路径,可以传入多个路径。如果不存在以’/’开始的参数,则函数会自动加上
        # 获得图片的所有列表
        self.img_path=os.listdir(self.path)  #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

    #获取每一个图片
    def __getitem__(self, item):
        #获取单张图片名称
        img_name=self.img_path[item]
        #获取单张图片相对路径
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        #读取单张图片
        img=Image.open(img_item_path)
        img=img.resize((256,256),Image.ANTIALIAS) #统一图片尺寸
        trans = transforms.ToTensor() #转换为tensor类型
        img_tensor = trans(img)#转换为tensor类型


        #获取lable
        label=self.label_dir
        return img_tensor, label

    #列表有多长
    def __len__(self):
        return len(self.img_path)

3. Dataloader的功能

dataloader的功能是为了实现从dataset中取数据,例如,每次取多少数据?,数据集是否打乱?,加载过程是单进程还是多进程?,如果最后剩余数据不足一次需要获取数据,剩余数据是否舍弃。

4. Dataloader代码

1) 查看官方文档解释

在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import DataLoader
help(DataLoader)

2) pycharm中进行程序编写

新建.py文件,写入以下内容

from dataset import mydata
from torch.utils.data import DataLoader
import torch

#准备测试数据集
root_dir= "dataset/val"
bees_label_dir="bees"
test_dataset=mydata(root_dir,bees_label_dir)#输入数据集路径

#数据载入
test_loader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

for data1 in test_loader:
    imgs,labs=data1
    print(imgs.shape)
    print(labs)

输出如下内容则表示数据载入成功
在这里插入图片描述

感谢: PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
视频网址:https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值