源码
第一种情况MNIST_Training_MyDataset
第二种情况MNIST_Training_With_No_Split
简介
本文主要探讨获取一个数据集的两种情况,以手写数据集为例
- 以文件格式划分好了训练集与测试集
- 文件没有划分测试集与训练集,需要通过代码进行划分
1.文件格式已经划分好了训练集与数据集
这种数据集我们主要就是要通过自定义的类将文件格式的数据转化为可以进行训练的数据集,主要通过以下几步
- 创建自定义的类,并继承
Dataset
类 - 重写
Dataset
的三个方法__init__
,用于根据数据集地址与数据转化类型来进行对数据初始化__len__
,用与获取数据集的长度__getitem__
,用于根据下标来获取数据集中的对应元素,并且返回图片与标签的二元组
基本代码结构为
class MyDataset(Dataset):
def __init__(self, root_path, transform=None):
pass
def __len__(self):
pass
def __getitem__(self, index):
pass
__init__
定义
__init__
主要实现一件事
将给出的数据集地址转化并保存为一个数据集列表
我们的文件结构为,下面的代码在必要处都给出了注释,读者可以自行阅读
def __init__(self, root_path, train, transform=None):
self.root_path = root_path
# 判断变化规则
self.transform = transform
# 判断是否是训练集
if train:
self.data_path = os.path.join(self.root_path, 'training')
else:
self.data_path = os.path.join(self.root_path, 'testing')
self.img_paths = []
self.labels = []
# 遍历每个子文件夹(标签)
for label_dir in os.listdir(self.data_path):
label_path = os.path.join(self.data_path, label_dir)
if os.path.isdir(label_path): # 只处理目录
# 遍历子文件夹中的所有图像文件
for img_name in os.listdir(label_path):
single_img_path = os.path.join(label_path, img_name)
# 将单个图片路径添加到img_paths中
self.img_paths.append(single_img_path)
# 将图片对应的标签添加到labels中
self.labels.append(int(label_dir))
__len__
定义
有了__init__
的定义,我们只需要返回len(img_paths)
即可
def __len__(self):
return len(self.img_paths)
__getitem__
定义
__getitem__
主要实现两个功能
- 从数据集中获取对应下标的图片,并且转化为给出的
transform
格式 - 获取数据集中对应下标的标签
def __getitem__(self, index):
img_path = self.img_paths[index]
img_PIL = Image.open(img_path).convert('L')
# 如果变化规则不为空
if self.transform is not None:
img_tensor = self.transform(img_PIL)
else:
img_tensor = img_PIL
# 确定对应下标的标签
label = self.labels[index]
return img_tensor, label
使用
from My_Dataset import *
root_dir = '../datasets/mnist_png'
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor()
])
my_train_dataset = MyDataset(root_dir, train=True, transform=transform)
my_test_dataset = MyDataset(root_dir, train=False, transform=transform)
2.文件没有划分测试集与训练集,需要通过代码进行划分
具体流程与第一种情况类似,区别是
- 额外定义了一个
spilt_dataset
划分数据集 - 在
__init__
中对完整数据集进行划分,划分为训练集与数据集 - 额外定义类
Subset
用于对划分结束的训练集与数据集规范化
__init__
定义
具体的数据集路径为
def __init__(self, root_path, transform=None):
# 记录成员变量
self.root_path = root_path
self.transform = transform
# 将图片与标签列为list
self.imgs_path = []
self.labels = []
# 获取root_path下的所有图片
for label_path in os.listdir(root_path):
img_path = os.path.join(root_path, label_path)
if os.path.isdir(img_path):
for img_name in os.listdir(img_path):
pre_img_path = os.path.join(img_path, img_name)
self.imgs_path.append(pre_img_path)
self.labels.append(label_path)
# 获取随机的random参数
random.seed(random.seed)
# 创建完整的数据集,内容为将图片和labels一一对应并且列为list
data = list(zip(self.imgs_path, self.labels))
# print(data)
# 将数据集打乱
random.shuffle(data)
# 设置训练集数据集划分比例
spilt_size = 0.8
split = int(len(data) * spilt_size)
# 根据比例划分训练集与数据集
self.train_data = data[:split]
self.test_data = data[split:]
相较于第一种方法,新增了许多代码,因为第一种方法从文件名就能知道训练集与数据集,而我们这种只能通过将完整的数据集打乱并且按比例取出训练集与数据集
__len__
定义
与第一种方法没有区别
def __len__(self):
return len(self.imgs_path)
__getitem__
定义
def __getitem__(self, index):
img_path = self.imgs_path[index]
label = self.labels[index]
# 根据图片地址获取图片信息,并且转化为灰度图像
img = Image.open(img_path).convert('L')
if self.transform is not None:
img = self.transform()
return img, label
也无区别
spilt_dataset
定义
这个方法主要是将划分完成的训练集与测试集进行返回,需要调用Subset
类
def spilt_dataset(self):
train_data = Subset(self, self.train_data, transform=self.transform)
test_data = Subset(self, self.test_data, transform=self.transform)
return train_data, test_data
这里不直接返回self.train_data
与self.test_data
是因为规范问题,自我感觉这样返回的两个数据集因为没有getitem方法,会导致在访问的时候出问题,但是自己也没有尝试,最好还是用这种格式吧
Subset类定义
class Subset(Dataset):
def __init__(self, dataset, indices, transform=None):
self.dataset = dataset
self.indices = indices
self.transform = transform
def __len__(self):
return len(self.indices)
def __getitem__(self, index):
img_path = self.indices[index][0]
label = self.indices[index][1]
img = Image.open(img_path).convert('L')
if self.transform is not None:
img = self.transform(img)
# 需注意
label = int(label)
label = torch.tensor(label)
return img, label
比较简单,唯一需要注意的是
我在用这种该方法获取数据集的标签label时,会出现获取的不是tensor数据类型而是tuple数据类型,就导致会出一些问题,网上搜到的解决方法是,先将标签转化为int类型,然后在用torch.tensor进行类型转化,这样训练就可以不出错了
使用方法
# 定义数据转换
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor(),
])
root_dir = '../datasets/mnist_png_with_no_spilt'
my_dataset = MyDataset(root_dir, transform=transform)
train_dataset, test_dataset = my_dataset.spilt_dataset()