PyTorch中如何读取数据(Dataset类的使用)

在pytorch中如何读取数据主要有两个类。

分别是Dataset和Dataloader。
dataset可以理解为:提供一种方式去获取数据及其label(标签)。
可以实现(1)如何获取每一个数据及其label;(2)总共有多少数据。这两个功能。

dataloader可以理解为:为后面的网络提供不同的数据形势。

Dataset类怎么去用?

from torch.utils.data  import Dataset

这段代码可以理解为:从torch大工具箱里面utils常用的工具区,关于数据的data区的。

可以使用help()函数查看,在jupyter或者pycharm控制台里面查询。
也可以直接在jupyter里输入Dataset??,直接可以查询。

Dataset的运用


class MyData( Dataset ) :  //创建一个class(MyData)继承Dataset类

class MyData( Dataset ) :  

def __init__(self):  //初始化类,比如说我们要根据这个类去创建一个特例的时候,它就要运行的一个函数。这个函数它一般会为整个class提供一个全局变量。为后面的一些函数提供一些所需要的量。

def __init__(self): 

def __getitem__(self, item) : 
它默认为item,我们改为def __getitem__(self, idx):  // idx可以看作一个编号

def __getitem__(self, item) : 


如果我们要通过这个idx(索引)来获取图片的地址的话,首先要获取这些图片的列表(list)。
如果需要获取所有图片的地址的话,我们就需要用到os(python中关于系统的一个库)

dir_path = ""  // ""中输入所有图片文件夹地址,我使用全地址报错了,改用相对地址后没问题
import os  //使用os
img_path_list = os.listdir(dir_path)  //将文件夹中的所有图片变成列表

如果我们要使用idxa去获取想要的图片的话,首先就要去创建图片地址的列表

def __init__(self, root_dir, label_dir)

使用python console验证。

import os
root_dir = ""    //  “”中输入放图片文件上一个文件的地址
label_dir = ""   // “”中输入放图片的地址
path = os.path.join(root_dir, label_dir)  //join这个给函数的作用就是在root_dir, 
label_dir两个地址之间添加一个\,将这两个路径进行拼接

接着,

def __init__(self, root_dir, label_dir)
      self.root_dir = root_dir     //为什么用self,我们知道一个函数中的变量是不能传
递给另外一个函数的变量的。而这个self,它可以把self指定的一个变量给后面的函数使用。它就
相当于指定了一个类中的全局变量。
      self.label_dir = label_dir
      self.path = os.path.join(self.root_dir, self.label_dir) //获得图片的路径地址
      self.img_path = os.listdir(self.path)  // 获得所有图片列表

如果我们想验证这个函数,可以在python console中验证。

如果要获取所有图片中某一个图片的话,

def __getitem__(self, idx):
     img_name = self.img_path[idx]  //  名称,从list里面读取遥感对应位置, 加self是
指全局的,引用上面的 self.img_path
     img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  //获
取某一个图片的路径

自此可以使用python console验证。

接着,可以使用import PIL import Image来读取图片

    img = Image.open(img_item_path)  //读取图片
        label = self.label_dir
        return img, label
    def __len__(self):          //查看这个列表的长度有多长
        return len(self.img_path)  


怎么读取电脑中的一张图片
from PIL import Image   //一个读取图片的方法
可以先在Python控制台进行调试。


 

from PIL import Image
img_path = ""  //获取图片地址 “”中输入图片地址
img = Image.open(img_path)
img.show()  //显示该图片

全部代码

from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset

 # "len(train_dataset)"指令可以在Python console中查看train_dataset数据集中有多少个元素。
img, label = train_dataset[230]
img.show()

  • 7
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

晓亮.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值