当我们配置好Pytorch环境后(3.6版本的Python),第一步先导入相关的库。
from cProfile import label
from tkinter.messagebox import showerror
from torch.utils.data import Dataset #读数据
from PIL import Image #图像处理
from torch.utils.tensorboard.summary import image
import os
之后再定义这个类:
在这里“ants”的数据集存储路径为:D:\Pychram_Project\Learn_Pytorch\pythonProject\dataset\train\ants
所用到的方法在注释中都有说明,在这里文件夹的名称就是label,这里还将路径名和标签名拼接。
class MyData(Dataset): #Dataset可获取数据集及其label
def __init__(self,root_dir,label_dir): #初始化函数
# root_dir="dataset/train"
# label_dir="ants"
self.root_dir=root_dir #self可把当前的变量转换成全局变量 转换文件夹路径
self.label_dir=label_dir #转换标签路径
self.path=os.path.join(self.root_dir,self.label_dir) #路径拼接,在不同系统中斜杠的意不同,用此函数可以避免歧义。
self.img_path=os.listdir(self.path) #转换成列表形式
def __getitem__(self, index): #读取其中的(列表中的对应位置)图片
img_name=self.img_path[index]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name) #把名字接在路径后面
img=Image.open(img_item_path)
label=self.label_dir #这是因为文件夹的名称就是label,所以这里直接让label=label_dir
return img,label
def __len__(self):
return len(self.img_path) # 返回列表的长度
测试部分如下:
其中,蜜蜂数据集的存储路径为D:\Pychram_Project\Learn_Pytorch\pythonProject\dataset\train\bees
我们在这里想要读取两个数据集就要创建两次对象。
#测试:
root_dir="dataset/train"
ants_label_dir="ants" #读取蚂蚁数据集
ants_dataset=MyData(root_dir,ants_label_dir)
print(ants_dataset[0]) #第一张蚂蚁图片的信息
img,label=ants_dataset[1] #可以同时接收两个返回值
print(img)
print(label)
#img.show()
bee_label_dir="bees" #读取蜜蜂数据集
bee_dataset=MyData(root_dir,bee_label_dir)
img2,label2=bee_dataset[0]
print(label2)
#img2.show()
为了避免多次创建对象,可以将其进行合并,这里的的图片和标签是在一起的,文件夹名字就是标签名。以后也可以把label都存到一个text文档里,再把所有的text文档整理成一个文件夹。
#数据集合并(拼接): #可用于label和img分离,创建一个label文件夹,里面用txt文档存储label,再把img和label拼接
train_dataset=ants_dataset+bee_dataset
print(len(train_dataset))
img3,label3=train_dataset[123]
print(label3)
#img3.show()
img4,label4=train_dataset[124]
print(label4)
img4.show()
(教学内容来自b站小土堆)