Dataset类 与 Dataloader类
- Dataset类作用:从海量数据中获取目标数据以及它们对应的标签(label),并为每个数据加上索引,可以通过索引来直接找到数据。(提供一种获取数据及其标签的方式)
- Dataloader类作用:将指定的目标数据以及对应的label进行打包,并发送给后面的神经网络(为后面网络提供不同的数据形式)
- 代码:
from torch.utils.data import Dataset
(引入Dataset类)
Dataset类:所有数据集都要继承这个类,并且重写该类中的__getitem__方法,该方法作用是每个获取数据及其对应label,用户也可以选择性重写__len__方法(其作用:返回数据的长度)
Dataset代码实战
一个练手的数据集 密码: 5suq
- 下载数据集,将其解压在对应的python工程目录下
- 代码如下(代码中使用了opencv这个第三方库,请自行下载)
from torch.utils.data import Dataset
import os
import cv2
class MyData (Dataset):
def __init__(self, types, root, label, image):
self.root_dir=root
self.label_dir=label
self.type=types
self.path_label=os.path.join(root, types, label) #获得存有所有标签的路径
self.path_image=os.path.join(root, types, image) #获得存有所有图片的路径
self.imgs=os.listdir(self.path_image) #获得所有图片
self.labs=os.listdir(self.path_label) #获得所有标签
def __getitem__(self, idx):
img_name=self.imgs[idx] #获得对应的标签文件名
label_name=self.labs[idx] #获得对应图片文件名
img_path=os.path.join(self.path_image, img_name)
label=os.path.join(self.path_label,label_name) #获得图片对应label的文件的地址
with open(label, "r", encoding="utf-8") as labs: #将图片对应label文件中的label读出来
lab=labs.read()
imgs=cv2.imread(img_path) #将对应图片读出来
return imgs, lab
def __len__(self):
return len(self.imgs)
if __name__ == '__main__':
cv2.namedWindow("img", cv2.WINDOW_NORMAL)
path="dataset"
ty="train"
labels1="ants_label"
labels2="bees_label"
img_name1="ants_image"
img_name2="bees_image"
ants_dataset=MyData(ty, path, labels1, img_name1) #蚂蚁训练集
bees_dataset=MyData(ty, path, labels2, img_name2) #蜜蜂训练集
train_dataset=ants_dataset+bees_dataset #合并训练集
img, lab1=train_dataset[127]
cv2.imshow('img', img) #读出图片
print(lab1) #打印出对应图片的标签
print(len(ants_dataset)) #获得数据集长度
cv2.waitKey(0)
cv2.destroyAllWindows()
```
```