深度学习框架_PyTorch_模型预训练_PyTorch使用预训练模型全流程(实例:CIFAR图像分类)

首先,这篇文章是接着下面这篇文章进行补充的,大家可以先看这篇文章如何手动设计一个分类网络:

深度学习框架_PyTorch_使用PyTorch编写卷积神经网络全流程(实例:CIFAR10图像分类)

一.预训练模型简介

卷积神经网络在图像识别,图像分割等领域取得了非常大的成功,出现了很多非常好的网络模型,比如ResNet、GoogLeNet等图像分类模型、Faster R-CNN、YOLO等目标检测模型。

在使用卷积网络解决我们自己遇到的图像识别任务时,通常先用这些优秀的模型作为基准,再针对自己所处理的识别问题的特点做有针对性的修改和设计,因为从0开始设计一个全新的网络结构并且取得非常好的效果是很困难的。

torchvision.models模块中给出了很多优秀的卷积网络的定义,比如VGG、ResNet、SqueezeNet、DenseNet、Inception、GoogLeNet、MobileNet等。这些模型都有对应的类,可以通过类的构造函数创建相应的模型。

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.googlenet()
mobilenet = models.mobilenet_v2()
resnet50_32x4d = models.resnet50_32x4d()

PyTorch的Model Zoo中还提供了这些模型在ImageNet等数据集上训练得到的权值参数。如果要使用这些预训练的参数,只需要在模型类的构造函数中指定参数pretrained值为True。在加载模型时,会从指定的网址下载模型参数(.pth文件)。

resnet50 = models.resnet50(pretrained =True)

如果本地已经下载了权值pth文件,也可以在创建模型时,指定pretrained=False(默认参数),然后从本地磁盘文件中加载状态词典文件,并加载到模型中。如下面代码所示:

resnet50 = models.resnet50()
#使用本地磁盘上的模型参数文件
state_dict = torch.load('resnet50.pth')
#把读入的模型参数加载到模型model中:
resnet50.load_state_dict(state_dict)

使用预训练模型时要注意对输入图像做指定的归一化操作,在ImageNet上训练的模型,大部分输入图像的大小是224 * 224 * 3,并且要转换为(3 * 224 * 224)形状的张量。图像的像素值做归一化,均值与标准差分别为mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225]。
接下来我们编写代码实现这个标准化操作:

normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])

二.使用预训练模型解决新的图像识别问题

面对一个新的图像识别任务,常见的做法是使用ImageNet上训练的模型做迁移学习:即使用已经训练好的模型,把最后的分类层(Softmax)改造为新的分类层,用新的图像数据做进一步训练,得到一个针对新任务的识别模型。

根据新任务中训练样本数量的多少,又可以分为以下三种迁移方式(以ResNet为例):

  1. 数据量非常少(比如每个类别只有几十张图像):用ResNet做为特征器,只用新数据训练Softmax层分类器。
  2. 数据量比较少(比如总数据量比较少,只有几千张图像):固定ResNet低层网络权值,用新数据训练高层权值,并且使用预训练的权值初始化高层权值。
  3. 数据量比较大:用预训练权值初始化整个网络,然后用新的数据微调全部参数。

下面我们还是以花卉识别为例,使用上面三种方式把ImageNet数据集上训练的ResNet50模型迁移到花卉识别任务上。

我们使用OxfordFlower102数据集,该数据集中有102种不同的花卉,每一种花卉有几十张图片。

下面我们编写代码实现训练集与测试集的数据加载器

train_dir = '...'
val_dir = '...'
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
train_set = datasets.ImageFolder(
	train_dir,
	trainsforms.Compose([
		transforms.PandomResizedCrop(224),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		normalize,
		]))

test_set = datasets.ImageFolder(
	val_dir,
	trainsforms.Compose([
		transforms.Resize(256),
		transforms.PandomResizedCrop(224),
		transforms.ToTensor(),
		normalize,
		]))
		
train_loader = torch.utils.data.DataLoader(train_set, batch_size = 32, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=2)

接下来,我们构造一个ResNet50模型,并且使用本地磁盘上的预训练权值参数

net = models.resnet50(pretrained=False)
state_dict = torch.load('.../resnet50.pth')
net.load_state_dict(state_dict=state_dict)

我们再使用ResNet50作为特征提取器。在这种情况下,完全不需要训练模型各个层的参数,只需要把各个参数的属性requires_grad设置为False即可

for param in net.parameters():
	param.requires_grad = False

然后如果我们想要部分微调ResNet50。我们对ResNet50的高层参数做微调,因此,只需把低层参数的属性requires_grad设置为False即可。

下面的代码列出了网络的各个层参数名称:

for name,_ in net.named_parameters():
	print(name)

假设只需要训练layer4的权值,那么我们可以把layer1-layer3的参数的梯度属性设置为False,如下面的代码所示:

exclude_layers = ['layer1', 'layer2', 'layer3']#这些层不训练
for name, param in net.named_parameters():
	for layer in exclude_layer:
		if name.startswith(layer):
			param.requires_grad = False
			break

如果我们要对参数全部微调,这种情况下不用做任何特殊处理,net的各层参数的requires_grad属性默认为True。

接下来我们要改造分类层,原始的ResNet模型是用于对ImageNet图像做分类的,分类层有1000个输出节点。在OxfordFlower102数据集上,要把这个分类层改造为输出102个节点。这只需要把原始网络中的FC层替换成新的全连接层即可,如下代码所示:

num_calsses = 102
featureSize = net.fc.in_features
net.fc = nn.linear(featureSize, num_classes)

最后,我们开始训练网络。

  • 4
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Rocky Ding*

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值