1.DataLoader and Dataset
数据模块又可以细分为 4 个部分:
数据收集:样本和标签。
数据划分:训练集、验证集和测试集
数据读取:对应于PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
数据预处理:对应于 PyTorch 的 transforms
功能:Dataset 是抽象类,所有自定义的 Dataset 都需要继承该类,并且重写__getitem()方法和__len()方法 。__getitem()方法的作用是接收一个索引,返回索引对应的样本和标签,这是我们自己需要实现的逻辑。len()方法是返回所有样本的数量。
首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调用DatasetFetcher根据index获取数据。在DatasetFetcher里会调用Dataset的__getitem()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser。
2.transforms
# 设置训练集的数据增强和转化
train_transform = transforms.Compose([
transforms.Resize((32, 32)),# 缩放
transforms.RandomCrop(32, padding=4), #裁剪
transforms.ToTensor(), # 转为张量,同时归一化
transforms.Normalize(norm_mean, norm_std),# 标准化
])
设置验证集的数据增强和转化,不需要 RandomCrop
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
当我们需要多个transforms操作时,需要作为一个list放在transforms.Compose中。需要注意的是transforms.ToTensor()是把图片转换为张量,同时进行归一化操作,把每个通道 0~255 的值归一化为 0~1。在验证集的数据增强中,不再需要transforms.RandomCrop()操作。然后把这两个transform操作作为参数传给Dataset,在Dataset的__getitem__()方法中做图像增强。
对数据进行均值为 0,标准差为 1 的标准化,可以加快模型的收敛。
3.数据增强![](https://i-blog.csdnimg.cn/blog_migrate/dfc88333e1aa81e1b52cba3e9ca3de80.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/9da3220ec208e421c7f23591ab6dc19e.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/fa28e503ce3ab7176aaa108fa7b9dcd8.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/480961551f5ada08a4be04b9b1cedd2a.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/3f61403aa1323a3ee28d966b77689214.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/4f6b5a76602684f4d0fa7b735801f342.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/855daaf8826f57dc926a118ded8529df.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/29cb3d07f7d3bb26bb2d09cf0a1abb2e.png)
![](https://i-blog.csdnimg.cn/blog_migrate/0576b300f70718e781252d5f69901e91.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/f4653af2b61f2b308b456f462be3a75a.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/34ea93b1b390090e6ae73a025bf31524.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/c5a76dc5ead529d0fb55482eaf5a024c.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/766aabcc363dd91c61b8b25387260ff5.png)
![](https://i-blog.csdnimg.cn/blog_migrate/c7b6ac91826621bad7edc9f135bae782.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/3fbd2a35f24a854e4c1f94c18c65e83e.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/544b7151b6c979b06559302776f2b993.png)
4. code
数据集划分
# -*- coding: utf-8 -*-
"""
# @file name : 1_split_dataset.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir) # 用来创建多层目录(单层请用os.mkdir)
if __name__ == '__main__':
random.seed(1)
dataset_dir = os.path.join("G:\\", "hello", "data", "Cat_dog_data") # 路径拼接
split_dir = os.path.join("G:\\", "hello", "data", "cat_dog_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1
for root, dirs, files in os.walk(dataset_dir):
for sub_dir in dirs:
imgs = os.listdir(os.path.join(root, sub_dir))
imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
random.shuffle(imgs)
img_count = len(imgs)
train_point = int(img_count * train_pct)
valid_point = int(img_count * (train_pct + valid_pct))
for i in range(img_count):
if i < train_point:
out_dir = os.path.join(train_dir, sub_dir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sub_dir)
else:
out_dir = os.path.join(test_dir, sub_dir)
makedir(out_dir)
target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
shutil.copy(src_path, target_path)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point - train_point,
img_count - valid_point))
dataset
# -*- coding: utf-8 -*-
"""
# @file name : dataset.py
# @author : yts3221@126.com
# @date : 2019-08-21 10:08:00
# @brief : 各数据集的Dataset定义
"""
import os
import random
from PIL import Image
from torch.utils.data import Dataset
random.seed(1)
rmb_label = {"cat": 0, "dog": 1}
class CATDataset(Dataset):
def __init__(self, data_dir, transform=None): # 初始化
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"cat": 0, "dog": 1}
# data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.data_info = self.get_img_info(data_dir)
self.transform = transform
# 根据索引index返回图像及标签,即获取图像
def __getitem__(self, index):
# 通过self.data_info函数得到图像路径和标签
path_img, label = self.data_info[index]
# 通过Image.open得到img
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
# 查看数据长度,即样本的数量,数据集的数量
def __len__(self):
return len(self.data_info)
# 自定义的函数
# 用于获取路径和标签
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))
return data_info