Dataset和DataLoader使用(pytroch框架入门])

1、 Dateset

  • 数据的位置
# dataset 提供一种方式取获取数据及其label
# 1、如何获取每一个数据机器lable
# 2、告诉我们总共有多少数据

from torch.utils.data import Dataset
from PIL import Image    #目的读取图片
import os

class MyData(Dataset):
    #类的实例化,为后面提供全局变量使用
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(root_dir,label_dir)#label的地址,下面有很多同一个label的图片
        self.img_path=os.listdir(self.path)#给label写成列表的形式即[0]为第一张图片


    #返回图片的属性和所属的标签
    def __getitem__(self, idx):
        img_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#每张图片的地址
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label

    #数据集的长度
    def __len__(self):
        return len(self.img_path)

root_dir=r"pytorch_learning\dataset\train"
daisy_label_dir=r"daisy"
roses_label_dir=r"roses"
daisy_dataset=MyData(root_dir,daisy_label_dir)
roses_dataset=MyData(root_dir,roses_label_dir)

train_dataset=daisy_dataset + roses_dataset

2、torchvision中的数据集的使用

在这里插入图片描述
在这里插入图片描述

  • 这里顺便拿一个图像分类的例子来介绍
    在这里插入图片描述
  • 数据集介绍
    在这里插入图片描述
  • 数据集使用
    在这里插入图片描述
  • 直接下载在线的训练集
    在这里插入图片描述
  • 下载的数据集+transform应用
import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_set=torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",transform=dataset_transform,train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",transform=dataset_transform,train=False,download=True)

writer=SummaryWriter("p10_logs")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("test_set",img,i)

writer.close()

在这里插入图片描述

3.DataLoader使用

  • 将数据加载我们的神经网络中(具体怎么加载?)
  • 首先就是去看文档
import torchvision
from torch.utils.data import DataLoader

# 数据集准备
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([
       torchvision.transforms.ToTensor()
])
test_data = torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",
                                         train=True,
                                         transform=dataset_transform,
                                         download=True
)

# 数据集的加载,====下面的括号里面说的是,一把从test_data中取出4个进行打包,再返回imgs和target
test_loader=DataLoader(dataset=test_data,
                                    batch_size=64,#一把抓四张图片
                                    shuffle=False,#每一轮epoch,本次和上一次是不一样的
                                    num_workers=0,#加载图片使用多少个线程,一般在linux下设置的,windows下默认为0
                                    drop_last=True#设置为“True”将删除最后一个不完整的批处理
)

#测试数据集中第一张图片及target
img,target=test_data[0]
print(img)
print(target)

writer=SummaryWriter("dataloader_logs")
for epoch in range(2):
    step=0
    for data in test_loader:
        imgs,targets=data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("droplast_epoach:{}".format(epoch),imgs,step)#这边是一坨图片
        step=step+1;

writer.close()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

栋哥爱做饭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值