pytorch学习-模型处理

1 什么是模型

模型是神经网络训练优化后得到的结果,包含了神经网络骨架及学习得到的参数。


2 模型处理包括哪些操作

网络模型库、自定义模型、预训练模型的加载和模型保存、多GPU训练网络保存与加载、模型训练和测试的两种模式


2.1 网络模型库torchvision.models

torchvision.models库提供了众多经典的网络结构与预训练模型,例如VGG、ResNet和Inception等,利用这些模型可以快速搭建物体检测网络,不需要逐层手动实现。torchvision包与PyTorch相独立,需要通过pip指令进行安装,如下:

pip install torchvision

以VGG模型为例,在torchvision.models中,VGG模型的特征层与分类层分别用vgg.features与vgg.classifier来表示,每个部分是一个nn.Sequential结构,可以方便地使用与修改。


VGG16的特征层包括13个卷积、13个激活函数ReLU、5个池化,一共31层
VGG16的分类层包括3个全连接、2个ReLU、2个Dropout,一共7层

from torchvision import models
vgg = models.vgg16()
print(vgg.features)
print(vgg.classifier)

在这里插入图片描述

在这里插入图片描述

2.2 自定义模型

参考 神经网络工具箱torch.nn


2.3 加载预训练模型

为什么要进行预训练模型加载
对于计算机视觉的任务,包括物体检测,我们通常很难拿到很大的数据集,在这种情况下重新训练一个新的模型是比较复杂的,并且不容易调整,因此,Fine-tune(微调)是一个常用的选择。
什么是Fine-tune
所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型,在自己的数据集上训练自己的模型。

加载预训练模型的两种方法:
第一种使用torchvision.models中自带的预训练模型

from torchvision import models
vgg = models.vgg16(pretrained=True)

第二种使用本地预训练模型(或训练过的模型)

利用load_state_dict,遍历预训练模型的关键字,如果出现在了VGG中,则加载预训练参数

import torch
from torchvision import models
vgg = models.vgg16()
state_dict = torch.load("your model path")
# 利用load_state_dict,遍历预训练模型的关键字,如果出现在了VGG中,则加载预训练参数
vgg.load_state_dict({k:v for k,v in state_dict.items() if k in vgg.state_dict()})

或者

vgg = models.vgg16()
vgg_dict = vgg.state_dict()
pretrained_dict = torch.load("your model path")
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(vgg_dict[k]) == np.shape(v)}
vgg_dict.update(pretrained_dict)
vgg.load_state_dict(vgg_dict) 

2.4 模型保存与加载

PyTorch 中保存模型的方式有许多种:

# 保存整个网络
torch.save(model, PATH) 
# 保存网络中的参数, 速度快,占空间少
torch.save(model.state_dict(),PATH)
# 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(),
            'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
            PATH)

同样的,PyTorch 中读取模型参数的方式也有许多种:

# 读取整个网络
model = torch.load(PATH)

# 读取 Checkpoint 中的网络参数
model.load_state_dict(torch.load(PATH))

# 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载:
# 1. 利用字典的 update 方法进行加载
Checkpoint = torch.load(Path)
model_dict = model.state_dict()
model_dict.update(Checkpoint)
model.load_state_dict(model_dict)
# 2. 利用 load_state_dict() 的 strict 参数进行部分加载
model.load_state_dict(torch.load(PATH), strict=False)

2.5 多GPU训练网络保存与加载

指定多卡训练的模型就不是原来模型的类型了,而是并行化后的模型:
由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module
可以打印出来看一下


并行化后的模型参数必须加载到并行化的模型中,没并行化的参数要加载到没并行化的模型中,不然会出bug。

模型并行化后,保存没有并行化的模型并加载

model = DefinedNetwork()
torch.save(model.module.state_dict(), 'model_name.pth') # 保存

model.load_state_dict(torch.load(PATH))  # 加载

2.6 模型训练和测试的两种模式

model.train()和model.eval()分别在训练和测试中都要写,它们的作用如下:
(1) model.train()
启用BatchNormalization和 Dropout,将BatchNormalization和Dropout置为True
(2) model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False
注意:在训练模块中千万不要忘了写model.train();在评估(或测试)模块千万不要忘了写model.eval()

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值