任务目标
实现对Pokemon数据集的预处理;(数据集在文末提供)
任务思路
了解数据集结构;
训练集、验证集与测试集的预划分;
数据集的加载。
实现过程
了解数据集结构
数据集文件夹结构
pokemon下属的每个子文件夹名称即为其下所有图片的label,如下图:
pokemon
├─bulbasaur
│ 00000000.png
│...
├─charmander
│ 00000000.png
│...
├─mewtwo
│ 00000000.jpg
│...
├─pikachu
│ 00000000.jpg
│...
└─squirtle
00000000.png
图片总数为1168,其中
各个类别图片数:
皮卡超(pikachu):234;超梦(mewtwo):239;杰尼龟(squirtle):223;
小火龙(charmander):238;妙蛙种子(bulbasaur):234;
图片格式:
.png:506;.jpg:657;.jpeg:4;.gif:1;
训练集、验证集与测试集的预划分
train | val | test |
60% | 20% | 20% |
数据集的加载
# 导包
import csv
import glob
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset # 自定义数据集的母类
from torch.utils.data import DataLoader
from torchvision import transforms
搭建基本框架
class Pokemon(Dataset):
def __init__(self, root, resize): # 初始化函数
super(Pokemon, self).__init__()
self.root = root # 数据集文件夹/pokemon所在路径
self.resize = resize
def __len__(self): # 获取总样本数量的函数
pass
def __getitem__(self, idx): # 返回当前idx样本数据及label的函数
pass
设置样本与标签的映射
将各图片与其所属文件夹名称作的label相对应
def __init__(self, root, resize): # 初始化函数
super(Pokemon, self).__init__()
self.root = root # 数据集文件夹/pokemon所在路径
self.resize = resize
self.name2label = {} # 创建形如{"name":label}的字典
for name in sorted(os.listdir(os.path.join(root))): # 将/pokemon下的文件夹名称排序后依次读入
if not os.path.isdir(os.path.join(root, name)): # 忽略除文件夹外的其他文件
continue
self.name2label[name] = len(self.name2label.keys()) # 将文件夹名称与自然数对应
print(self.name2label) # 1168, ['pokemon\\bulbasaur\\00000000.png'...]
创建.csv文件
将所有图片样本的路径与其标签以形如image_path(不直接加载图片本身,防止爆内存), label对象的形式保存到一个.csv文件中
def load_csv(self, filename):
if not os.path.exists(os.path.join(self.root, filename)): # 判断.csv文件是否存在
images = []
# 将所有文件夹下的各种格式的图片以 'pokemon\\bulbasaur\\00000000.png' 的形式保存images列表中
for name in self.name2label.keys():
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
images += glob.glob(os.path.join(self.root, name, '*.gif'))
print(len(images), images) # 1168, ['pokemon\\bulbasaur\\00000000.png'...]
random.shuffle(images) # 将图片顺序随机打乱
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
# 将image列表的信息写入.csv文件中
for img in images:
name = img.split(os.sep)[-2] # 取 'pokemon\\bulbasaur\\00000000.png' 的倒数第二项作为name
label = self.name2label[name] # 通过name获得图片的label
writer.writerow([img, label]) # 以 'pokemon\\bulbasaur\\00000000.png', 0 为一行的形式写入.csv文件中
print('writen into csv file:', filename)
读取样本与标签
从.csv文件中读取刚刚保存的样本与标签,并将其添加到image和label列表中
# 接上文
images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row # 读取.csv文件的一行
label = int(label) # 将label转为int类型
images.append(img)
labels.append(label)
assert len(images) == len(labels) # 保证样本与标签数目一致,一一对应
return images, labels
划分训练集、验证集与测试集
依照初始化函数的mode参数划分训练集、验证集与测试集
def __init__(self, root, resize, mode): # 初始化函数
super(Pokemon, self).__init__()
self.root = root # 数据集文件夹/pokemon所在路径
self.resize = resize
self.name2label = {} # 创建形如{"name":label}的字典
for name in sorted(os.listdir(os.path.join(root))): # 将/pokemon下的文件夹名称排序后依次读入
if not os.path.isdir(os.path.join(root, name)): # 忽略除文件夹外的其他文件
continue
self.name2label[name] = len(self.name2label.keys()) # 将文件夹名称与数字对应
# print(self.name2label) # 1168, ['pokemon\\bulbasaur\\00000000.png'...]
# image, label
self.images, self.labels = self.load_csv('image.csv') # 将上文的images和labels读进来
# 将全部图片按照预处理的比例划分
if mode == 'train': # 60% = 0 ~ 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20% = 60% ~ 80%
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # elif mode == 'test': # 20% = 80% ~ 100%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]
数据读入
将数据通过__getitem__()读入,并对数据进行简单的预处理
def __len__(self):
return len(self.images) # 返回所选集的样本数量
def __getitem__(self, idx):
# idx:[0 ~ len(images)]
# self.images, self.labels
# img: 'pokemon\\squirtle\\00000205.jpg'
# label: 0
img, label = self.images[idx], self.labels[idx]
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # 将path所指的图片读入
设置图片形状
将图片形状统一,方便数据的读入和统一操作
# 接上文
transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
# 放大1.25倍是避免之后旋转裁剪操作导致图片出现黑边
类型转换
将处理好的数据转换成Tensor数据类型,方便后期使用pytorch框架进行训练
# 接上文
transforms.ToTensor(),
数据扩充
通过对原有图片的旋转与裁剪,来达到增加数据集规模的效果
# 接上文
transforms.RandomRotation(15), # 将图片随机旋转15°(旋转角度不能过大)
transforms.CenterCrop(self.resize), # 将旋转后的图片中心裁剪
数据归一化
通过x = (x - mean) / std操作将图片像素值控制在0~1之间,让训练更加稳定,loss收敛更快
# 接上文
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
# mean和std的数据来自ImageNet的统计数据,对于图像分类问题具有较好的普适性
外部链接
Pokemon数据集:
链接: https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw
提取码:dsxl
文章项目链接:
链接:https://pan.baidu.com/s/1qg1tw4SBCJaZSUgynz_jwQ
提取码:indf