总体流程
- 数据预处理部分:
- 数据增强:torchvision中transforms模块自带功能,当原始数据不够多的时候通过Data Augmentation(数据增强)使得图片的数量变多。CV中常用数据增强的方法有:对图片进行翻转、放大、缩小
- 数据预处理
- DataLoader模块直接读取batch数据
- 网络模块设置:
- 加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
- 需要注意的是别人训练好的任务跟我们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成自己需要的任务
- 训练时可以全部重头训练,也可以只训练最后任务的层,因为前几层都是做特征提取的,本质任务目标是一致的
- 网络模型保存与测试
- 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
- 读取模型进行实际测试
数据预处理
- 数据增强:
其中主要包含几个步骤:
- 图片以中心裁剪成224*224的,以提供给模型输入
- 用随机翻转做模型增强、亮度、对比度、转灰度
- 转换成tensor格式
- 为了迁移学习更好,需要与之前训练模型样本imagenet的处理一样,减去均值,除以方差
- 训练集与验证集需要选用相同的预处理方法。
from torchvision import transforms
data_transforms = {
'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(224),#从中心开始裁剪,模型输入要求
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
]),
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
- 数据读取:
- 通过ImageFolder读取数据
- batch 数据制造
from torchvision import models, datasets
batch_size = 8
# 第一个参数是图片的路径,第二个参数是转换的方法
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
迁移学习
通过迁移学习,在模型训练阶段不使用随机初始化模型而使用已经训练好的模型参数提取出来进行训练。迁移学习的两种方法
- 用训练好模型的参数作为初始化,再训练整个网络。
- 只改动最后一层全连接层,只重新训练全连接层
""" Resnet152
"""
# 下载预训练的模型
model_ft = models.resnet152(pretrained=use_pretrained)
# 模型是参数否需要冻住
set_parameter_requires_grad(model_ft, feature_extract)
# 改变最后一层全连接层
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102),
nn.LogSoftmax(dim=1))
input_size = 224