介绍:官方模型的调用、修改,以及模型的保存与读取
这里以vgg16为例子: vgg16_false=torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT) vgg16在torchvision.models模块下,weight表示权重(决定下载那个参数),如果为空,那么不下载,
如果想要修改默认下载位置:看https://blog.csdn.net/weixin_62769552/article/details/130295089,讲的很清楚
模型的保存有两中方法:(具体看代码)
第一种为保存网络结构以及参数:
但是在调用的时候要让程序能够访问到模型定义
第二种只保存参数:
以字典的形式保存参数,在导入参数的时候要先加载模型
代码展示:
import torch
import torchvision
from torch import nn
from torch.nn import ReLU
from torchvision.models import VGG16_Weights
train_data=torchvision.datasets.CIFAR10("../dataset",True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10("../dataset",False,transform=torchvision.transforms.ToTensor(),download=True)
#注意这里weights的定义以及,这里不再是使用pretrained选择是否进行下载权重而是
#由weights开定义的,如果weights为空,那么不下载参数,weights为谁就下载谁
vgg16_false=torchvision.models.vgg16()
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
# vgg16_true.add_module("add_Linear",nn.Linear(1000,10)) #在最后添加add_Linear标签,添加nn.Linear(1000,10)
# vgg16_true.classifier[0].add_module("add_Linear",ReLU(inplace=True))#在classifier第零个添加ReLU(inplace=True)
# vgg16_true.classifier[6]=nn.Linear(4096,10)#同理:修改线性层
# print(vgg16_true)
#方式一
# 保存模型以及参数
torch.save(vgg16_true,"vgg16_method1.pth")
#加载模型
model_vgg16_method1=torch.load("./vgg16_method1.pth")
print(model_vgg16_method1)
#陷阱
#用自己的网络模型保存的模型要让程序能够访问到模型定义的方式
#可以是import导入,也可以复制到本页面
# 方式二(官方推荐)
#只保留参数
torch.save(vgg16_true.state_dict(),"vgg16_method2.pth")#以字典形式保存
#加载模型
vgg16_mothod2=vgg16_false
vgg16_mothod2.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_true.state_dict())#查看参数
- 网络模型为什么这么大:因为参数太多了:深度学习模型通常包含大量的神经元和连接权重。这些参数是用来捕捉输入数据的特征和模式,从而实现模型对复杂问题的学习和预测能力。参数有权重、偏置、卷积核、循环神经网络(RNN)的参数。
- 如果不进行预训练,weights为空, 有些版本为none
- 修改网络模型,改成自己想要的模型,这是不是就是迁移学习呢!