Pytorch之Dataset和Dataloader(加载数据)

#Pytorch学习


前言。。

首先深度学习需要数据集,而数据集的处理离不开Dataset类和DataLoader类。
简单的区别一下Dataset和Dataloader:
Dataset是一个需要我们实现的抽象类,通过我们相关实现,表示数据集,实现数据集的具体功能。
而Dataloader,我们一般直接调用,调用来加载数据集。

先说一下常见的数据组成形式:

  1. 一个文件夹对应一个类,data根据类被分成不同的文件夹
  2. 一个图片文件对应名称为标签,
  3. 两个文件夹,一个文件夹存图片,另一个存标签。有的标签过于复杂,适合单独存放,而不放入名字里

一、Dataset

1.Dataset 是torch.utils.data.dataset包里一个抽象类,可以用来创建数据集,我们需要重写子类来完成数据集的创建。
2.其中最重要的两个方法就是len和getitem,也是子类必须要重写的方法,len函数返回数据集长度,getitem用来查找数据和标签

1.针对不同的文件夹存放不同的类数据

在这里插入图片描述
在这里插入图片描述
这里图片的名字并没有很大的意义。在这里插入图片描述

代码实现:

from torch.utils.data import Dataset
from PIL import Image   #这里用来读取图片数据
import os

class MyData(Dataset):  # 这里定义了一个MyData类继承Dataset读取数据

    def __init__(self, root_dir, label_dir):        #进行初始化数据
        self.root_dir = root_dir        #比如 "data/train"  相对路径
        self.label_dir = label_dir      #这里可以是 "ant"
        self.path = os.path.join(self.root_dir,self.label_dir)      #利用一个函数将其合并为ant标签对应的路径
        # os.listdir()会将path路径下对应的所有文件的“文件名”转化成列表中的一个个元素
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):  #实现这个方法是为了获取第idx个文件对应的数据和标签的
        img_name = self.img_path[idx]   #这个列表就可以返回名字呢
        img_item_path = os.path.join(self.path,img_name)    #将其之前的路径和文件名拼起来就是最终路径了
        img = Image.Open(img_item_path)  #数据读取到img变量里了
        label = self.label_dir  #标签就是上层文件夹名
        return img,label       #返回img,label完成重写函数任务

    def __len__(self):
        return len(self.img_path)  #列表的长度就是数据的长度

2.再针对另一种情况,image和label分开

在这里插入图片描述
其实很简单再引入个label_path变量就可以了。

二、torchvision中数据集的使用

pytorch官网里可以发现,这里提供了非常多常用的数据集,比如Minist,CIFAR等。
在这里插入图片描述
使用 看代码就可以明白了

import torchvision #一个处理图像的库,包括数据集加载和预处理
from torch.utils.tensorboard import SummaryWriter
dataset_tans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor() #这里用compose生成一个合成transforms工具预处理数据
])
#这里的几个参数,root是下载目录,
# train的true或者false表示是训练集还是测试集,
# download一般设置成true进行下载,如果已经下载好了可以不设置这个参数
#同时利用transform参数将数据转化成tensor类型
#ctrl+p可以查看调用的函数需要什么参数
train_set = torchvision.datasets.CIFAR10(root="./dataset",train = True,transform= dataset_tans,download = True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train = False,transform= dataset_tans,download = True)
writer = SummaryWriter("logs")
#print(test_set[0])
for i in range(10):
    img,target = test_set[i]
    writer.add_image("嘻嘻",img,i)

writer.close()

二、Dataloader的使用

from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import DataLoader

#数据设置与下载
#这里将数据转化成tensor格式
test_data = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor())

#加载数据

#调用Dataloader类必须先传入dataset,之前设置的数据
#然后还要设置batch_size
#shuffle 为True,每个epoch会刷新数据的顺序
#num_works = 0 ,会单进程处理数据
#drop_last 为True batch处理的数据即便无法整除也不会丢弃
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=True)

writer = SummaryWriter("loader")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs,targets = data  #这里的imgs是根据batch读的64张图片,target也是64个标签
        writer.add_images("Eposch.{}".format(epoch),imgs,step)    #注意这里是add_images
        step += 1       #step++

writer.close()

在这里插入图片描述


  • 7
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值