深度学习模型

介绍:官方模型的调用、修改,以及模型的保存与读取

这里以vgg16为例子: vgg16_false=torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT) vgg16在torchvision.models模块下,weight表示权重(决定下载那个参数),如果为空,那么不下载,

如果想要修改默认下载位置:看https://blog.csdn.net/weixin_62769552/article/details/130295089,讲的很清楚

模型的保存有两中方法:(具体看代码)
第一种为保存网络结构以及参数:
但是在调用的时候要让程序能够访问到模型定义
第二种只保存参数:
以字典的形式保存参数,在导入参数的时候要先加载模型

代码展示:

import torch
import torchvision
from torch import nn
from torch.nn import ReLU
from torchvision.models import VGG16_Weights

train_data=torchvision.datasets.CIFAR10("../dataset",True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10("../dataset",False,transform=torchvision.transforms.ToTensor(),download=True)

#注意这里weights的定义以及,这里不再是使用pretrained选择是否进行下载权重而是
#由weights开定义的,如果weights为空,那么不下载参数,weights为谁就下载谁
vgg16_false=torchvision.models.vgg16()
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)


# vgg16_true.add_module("add_Linear",nn.Linear(1000,10)) #在最后添加add_Linear标签,添加nn.Linear(1000,10)
# vgg16_true.classifier[0].add_module("add_Linear",ReLU(inplace=True))#在classifier第零个添加ReLU(inplace=True)
# vgg16_true.classifier[6]=nn.Linear(4096,10)#同理:修改线性层

# print(vgg16_true)


#方式一
# 保存模型以及参数
torch.save(vgg16_true,"vgg16_method1.pth")
#加载模型
model_vgg16_method1=torch.load("./vgg16_method1.pth")
print(model_vgg16_method1)
#陷阱
#用自己的网络模型保存的模型要让程序能够访问到模型定义的方式
#可以是import导入,也可以复制到本页面


# 方式二(官方推荐)
#只保留参数
torch.save(vgg16_true.state_dict(),"vgg16_method2.pth")#以字典形式保存
#加载模型
vgg16_mothod2=vgg16_false
vgg16_mothod2.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_true.state_dict())#查看参数

  • 网络模型为什么这么大:因为参数太多了:深度学习模型通常包含大量的神经元和连接权重。这些参数是用来捕捉输入数据的特征和模式,从而实现模型对复杂问题的学习和预测能力。参数有权重、偏置、卷积核、循环神经网络(RNN)的参数。
  • 如果不进行预训练,weights为空, 有些版本为none
  • 修改网络模型,改成自己想要的模型,这是不是就是迁移学习呢!
  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值