pytorch加载数据

参考:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
本文是上面视频的笔记,up主讲的特别详细,推荐观看。
在pytorch中加载数据主要涉及到两个类:Dataset 和 Dataloader
Dataset :提供一种方式去提取数据并得到label
Dataset:对数据进行打包送到网络中去,为后面的网络提供不同的数据形式。
下面是代码及说明:

from torch.utils. data import Dataset

在这里插入图片描述
可看到说明,Dataset是一个抽象类,我们重写Dataset时要继承这个类,所有的子类都应该重写__getitem__()方法,这个方法作用是获取数据及对应的labe。同时我们可以选择性地去重写__len__方法,其作用是获取数据集长度。

例子:

这里我使用的是猫狗二分类的数据集,如图:
在这里插入图片描述

from torch.utils. data import Dataset
from PIL import Image
import os

class Mydataset(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir    ##根目录
        self.label_dir = label_dir  ##标签,也就是文件名
        self.path = os.path.join(self.root_dir,self.label_dir)  ##拼成一个完整的目录
        self.img_path = os.listdir(self.path) ##获得图片的一个list

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  ##得到单个图片的名字
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  ##得到单个图片的路径
        img = Image.open(img_item_path) ##图片数据
        label = self.label_dir ##标签
        return img, label
    def __len__(self):
        return len(self.img_path)

root_dir="D:/猫狗大战/data/train"
cat_label_dir = "cat"
dog_label_dir = "dog"
cat_dataset = Mydataset(root_dir,cat_label_dir)
dog_dataset = Mydataset(root_dir,dog_label_dir)
img, label = cat_dataset[1]
img.show()
print(label)

img, label = dog_dataset[1]
img.show()
print(label)

输出结果:
cat
dog
在这里插入图片描述
写给自己,另外,可以参考这篇博客:
https://ptorch.com/news/215.html
fastai也可以关注以下

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

通信仿真爱好者

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值