在pytorch中如何读取数据主要有两个类。
分别是Dataset和Dataloader。
dataset可以理解为:提供一种方式去获取数据及其label(标签)。
可以实现(1)如何获取每一个数据及其label;(2)总共有多少数据。这两个功能。
dataloader可以理解为:为后面的网络提供不同的数据形势。
Dataset类怎么去用?
from torch.utils.data import Dataset
这段代码可以理解为:从torch大工具箱里面utils常用的工具区,关于数据的data区的。
可以使用help()函数查看,在jupyter或者pycharm控制台里面查询。
也可以直接在jupyter里输入Dataset??,直接可以查询。
Dataset的运用
class MyData( Dataset ) : //创建一个class(MyData)继承Dataset类
class MyData( Dataset ) :
def __init__(self): //初始化类,比如说我们要根据这个类去创建一个特例的时候,它就要运行的一个函数。这个函数它一般会为整个class提供一个全局变量。为后面的一些函数提供一些所需要的量。
def __init__(self):
def __getitem__(self, item) :
它默认为item,我们改为def __getitem__(self, idx): // idx可以看作一个编号
def __getitem__(self, item) :
如果我们要通过这个idx(索引)来获取图片的地址的话,首先要获取这些图片的列表(list)。
如果需要获取所有图片的地址的话,我们就需要用到os(python中关于系统的一个库)
dir_path = "" // ""中输入所有图片文件夹地址,我使用全地址报错了,改用相对地址后没问题
import os //使用os
img_path_list = os.listdir(dir_path) //将文件夹中的所有图片变成列表
如果我们要使用idxa去获取想要的图片的话,首先就要去创建图片地址的列表
def __init__(self, root_dir, label_dir)
使用python console验证。
import os
root_dir = "" // “”中输入放图片文件上一个文件的地址
label_dir = "" // “”中输入放图片的地址
path = os.path.join(root_dir, label_dir) //join这个给函数的作用就是在root_dir,
label_dir两个地址之间添加一个\,将这两个路径进行拼接
接着,
def __init__(self, root_dir, label_dir)
self.root_dir = root_dir //为什么用self,我们知道一个函数中的变量是不能传
递给另外一个函数的变量的。而这个self,它可以把self指定的一个变量给后面的函数使用。它就
相当于指定了一个类中的全局变量。
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir) //获得图片的路径地址
self.img_path = os.listdir(self.path) // 获得所有图片列表
如果我们想验证这个函数,可以在python console中验证。
如果要获取所有图片中某一个图片的话,
def __getitem__(self, idx):
img_name = self.img_path[idx] // 名称,从list里面读取遥感对应位置, 加self是
指全局的,引用上面的 self.img_path
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) //获
取某一个图片的路径
自此可以使用python console验证。
接着,可以使用import PIL import Image来读取图片
img = Image.open(img_item_path) //读取图片
label = self.label_dir
return img, label
def __len__(self): //查看这个列表的长度有多长
return len(self.img_path)
怎么读取电脑中的一张图片
from PIL import Image //一个读取图片的方法
可以先在Python控制台进行调试。
from PIL import Image
img_path = "" //获取图片地址 “”中输入图片地址
img = Image.open(img_path)
img.show() //显示该图片
全部代码
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
# "len(train_dataset)"指令可以在Python console中查看train_dataset数据集中有多少个元素。
img, label = train_dataset[230]
img.show()