一、前言
- 图片增强
提高图片泛化度,包括 旋转、翻转、拉伸、色彩抖动等处理,需要根据具体图片类型来决定,比如,我做猫狗二分类,那么旋转、拉伸、翻转、抖动都可以,但是我如果做的是比较严谨的分类比如医学相关的,那么翻转、拉伸、色彩抖动就别整了或者参数调小点 - 归一化与标准化
<1>图片像素值统一除以255,归一化到 [0,1] 之间
<2>再将归一化的结果减去0.5,除以0.5,标准化到 [-1, 1] 之间 - 训练数据可以进行增强和归一化,但是预测数据只进行归一化即可,训练数据增强是为了让模型适应更多不确定性的环境,但是预测的时候就不要把图片转来转去拉来拉去给自己找麻烦了(狗头),这样也能保证输出结果稳定、唯一
二、预处理与增强
1.针对训练数据
- 单个图片的增强与加载
from PIL import Image
from torchvision import transforms
def get_transform_for_predict():
'''
图片数据转换
:return:
'''
return transforms.Compose([
transforms.Resize(size=(224, 224)), # 图片拉成 224*224
transforms.RandomHorizontalFlip(p=0.3), # 将三成图片水平翻转
transforms.RandomVerticalFlip(p=0.3), # 将三成图片垂直翻转
transforms.RandomPerspective(distortion_scale=0.3, p=0.3), # 将三成图片不规则拉伸,拉伸力度0.3
transforms.RandomRotation(degrees=(0, 180)), # 图片随机旋转,0-180度
transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 图片随机抖动,四个对应的值分别是 亮度、对比度、饱和度、色调
transforms.ToTensor(), # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
])
if __name__ == '__main__':
# 加载图片
img_path = 'dataset/train/dog/dog.0.jpg'
img = Image.open(img_path).convert('RGB')
# 初始化数据增强方法
tsf = get_transform_for_predict()
# 图片增强
X = tsf(img)
# # show
# plt.imshow(X.permute(1, 2, 0).numpy())
# plt.show()
- 批量增强与加载
创建如下目录结构
datasets # 根路径
|
|----train # 训练集
|
|----cat # 样本目录1
| |
| |----01.jpg # 图片数据1
| |----02.jpg # 图片数据2
| |----...
|
|
|----dog # 样本目录2
|
|----01.dog # 图片数据1
|----02.dog # 图片数据2
|----...
# 加载后 cat、dog...等样本目录 会依次变成 0、1... 等标签,分别对应样本目录下的图片数据
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
def get_transform_for_train():
'''
图片数据转换
:return:
'''
return transforms.Compose([
transforms.Resize(size=(224, 224)), # 图片拉成 224*224
transforms.RandomHorizontalFlip(p=0.3), # 将三成图片水平翻转
transforms.RandomVerticalFlip(p=0.3), # 将三成图片垂直翻转
transforms.RandomPerspective(distortion_scale=0.3, p=0.3), # 将三成图片不规则拉伸,拉伸力度0.3
transforms.RandomRotation(degrees=(0, 180)), # 图片随机旋转,0-180度
transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 图片随机抖动,四个抖动因子分别是 亮度、对比度、饱和度、色调
transforms.ToTensor(), # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
])
def split_data_to_train_and_valid(datasets, validation_split=0.1):
'''
将数据集分为训练集和验证集
:param datasets:
:param size:
:return:
'''
train_size = int((1-validation_split) * len(datasets)) # 训练集size
validation_size = len(datasets) - train_size # 验证集size
train_dataset, validation_dataset = random_split(datasets, [train_size, validation_size]) # 拆分
return train_dataset, validation_dataset
def get_data_iter(data_path, batch_size, validation_split):
'''
获取训练张量
:return:
'''
datasets = ImageFolder(
root=data_path, # 数据集路径
transform=get_transform_for_train() # 图片增强
)
# 拆分数据集和验证集
train_dataset, validation_dataset = split_data_to_train_and_valid(datasets, validation_split)
# 分别加入迭代器
train_iter = DataLoader(
train_dataset, # 训练集
batch_size=batch_size, # 批量大小
shuffle=True, # 是否乱序
num_workers=4 # 加载时使用的进程并发数
)
validation_iter = DataLoader(
validation_dataset, # 训练集
batch_size=batch_size, # 批量大小
shuffle=True, # 是否乱序
num_workers=4 # 加载时使用的进程并发数
)
return train_iter, validation_iter
if __name__ == '__main__':
data_path = r'E:\数据集\猫狗数据集\kaggle_Dog&Cat\train'
# 加载数据并拆分为训练集和验证集
train_iter, validation_iter = get_data_iter(
data_path=data_path, # 数据集位置
batch_size=64, # 批量大小
validation_split=0.3 # 三成数据作为验证集
)
# 到此为止就可以送入网络训练了,下面的是打印检查数据 ------------
# 打印每个batch的训练集和验证集长度
print(len(train_iter), len(validation_iter))
# 274 118
# 我这里准备了25000张图片
# 训练集每批数据量 = 25000 / 64 * (1 - 0.3) = 274
# 验证集每批数据量 = 25000 / 64 * 0.3 = 118
# 打开训练集的第一张图片康康
for index, batch_data in enumerate(train_iter):
# 打印索引、训练集尺寸、标签尺寸
print(index, batch_data[0].shape, batch_data[1].shape)
# 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
# 打印第一批数据的第一张图的标签和样本
print(batch_data[0][0], batch_data[1][0])
# 样本
# tensor([[[-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765],
# [-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765],
# [-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765],
# ...,
# [-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765],
# [-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765],
# [-0.9765, -0.9765, -0.9765, ..., -0.9765, -0.9765, -0.9765]]])
# 标签
# tensor(1)
for data in batch_data[0]:
data = data.permute(1, 2, 0) # 通道维度放到最后
plt.imshow(data.numpy()) # 转成numpy展示图片
plt.show()
break
break
# 打开验证集的第一张图片康康
for index, batch_data in enumerate(validation_iter):
# 打印索引、训练集尺寸、标签尺寸
print(index, batch_data[0].shape, batch_data[1].shape)
# 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
# 打印第一批数据的第一张图的标签和样本
print(batch_data[0][0], batch_data[1][0])
# 样本
# tensor([[[-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922],
# [-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922],
# [-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922],
# ...,
# [-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922],
# [-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922],
# [-0.9922, -0.9922, -0.9922, ..., -0.9922, -0.9922, -0.9922]]])
# 标签
# tensor(0)
for data in batch_data[0]:
data = data.permute(1, 2, 0) # 通道维度放到最后
plt.imshow(data.numpy()) # 转成numpy展示图片
plt.show()
break
break
2.针对预测数据
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
def get_transform_for_predict():
'''
图片数据转换
:return:
'''
return transforms.Compose([
transforms.Resize(size=(224, 224)), # 图片拉成 224*224
# transforms.RandomHorizontalFlip(p=0.3), # 将三成图片水平翻转
# transforms.RandomVerticalFlip(p=0.3), # 将三成图片垂直翻转
# transforms.RandomPerspective(distortion_scale=0.3, p=0.3), # 将三成图片不规则拉伸,拉伸力度0.3
# transforms.RandomRotation(degrees=(0, 180)), # 图片随机旋转,0-180度
# transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 图片随机抖动,四个对应的值分别是 亮度、对比度、饱和度、色调
transforms.ToTensor(), # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
])
if __name__ == '__main__':
# 加载图片
img_path = 'dataset/train/dog/狗砸.png'
img = Image.open(img_path).convert('RGB')
# 初始化数据增强方法
tsf = get_transform_for_train()
# 图片增强
X = tsf(img)
# # show
# plt.imshow(X.permute(1, 2, 0).numpy())
# plt.show()