本文参考b站up主:我是土堆
机器学习的第一步是构建数据集,构建数据集主要用到下面两个:
- Dataset
作用:提供一种方式去获取数据及其label
思路:如何获取每一个数据及其label
、告诉我们总共有多少数据 - Dataloader
作用:为后面的网络提供不同的数据形式
思路:从Dataset
中取数据,怎么取看自己的参数设定
常用的数据集基本是由(数据,label
)的形式构成的。
label
简单的话,可以统一为数据文件夹的名字。见下图:
若label
比较麻烦,可以新建一个文件夹,里面为txt
文档,文档名字与数据名字一致,文档内容为该数据的label
。见下图:
Dataset
首先明确,这一步要得出的结果是表示出数据集,即(数据,label
)的形式。
我们需要重写下面这个类:
from torch.utils.data import Dataset
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
def __getitem__(self, idx):
return img, label
def __len__(self):
return len(...)
结合上面给出的文件夹,我们对该类补充之后如下:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # self:使其成为该类中的全局变量
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir) # 拼接路径
self.img_path = os.listdir(self.path) # 将该路径对应文件夹下的文件以列表形式存储
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
主函数部分代码为:
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir) # 这样就可以返回(数据, label)的形式了
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset # 数据集拼接
我们以ants_dataset
为例,在控制台看看它里面存储了什么:
>>> ants_dataset = MyData(root_dir, ants_label_dir)
>>> ants_dataset[0]
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512 at 0x23C4BC3D5E0>, 'ants')
这样,第一步任务就做完了,我们的数据集已经存储在train_dataset
中。