神经网络——网络模型

1.现有网络模型的使用及修改

1.1讲解VGG16网络模型:

import torchvision

vgg16_false=torchvision.models.vgg16(weights =None)
vgg16_true=torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)

在这里插入图片描述

ImageNet数据集太大了,仅训练集就有147.9g。

1.2改变现有网络的参数:

  • 在现有模型中添加模型
vgg16_true.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

在这里插入图片描述

  • 添加模型至classifier中
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

在这里插入图片描述

  • 修改模型
print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

在这里插入图片描述

2.网络模型的保存与读取

2.1模型的保存与读取方法1:

  • torch.save(实例, 保存名称)——model_save.py
  • torch.load(实例, 保存名称)——model_load.py

方法1:保存了模型结构+模型参数

#保存方式1
torch.save(vgg16,"vgg16_method1.pth")
#读取方式1:加载模型
model=torch.load("vgg16_method1.pth")#D:\py_code5\XTD\XTDProject_3\vgg16_method1.pth
print(model)

2.2模型的保存与读取方法2:

  • torch.save(实例.state_dict(), 保存名称) ——model_save.py
  • torch.load(实例.state_dict(), 保存名称)——model_load.py

方法2保存的是:模型参数(官方推荐),vgg16的网络模型状态保存为字典格式。不保存结构。

#保存方式2
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
#读取方式2,加载模型
model=torch.load("vgg16_method2.pth")

在这里插入图片描述

是字典格式,需要还原:

#方式二,加载模型
vgg16=torchvision.models.vgg16(weights = None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model=torch.load("vgg16_method2.pth")
print(vgg16)

即可。

方式二的数据大小要小一点
在这里插入图片描述

2.3方法1的陷阱:

用方法1的时候一定要保证读取模型的文件里有定义该模型的类!

#陷阱——model_save.py
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=3)

    def forward(self, x):
        x=self.conv1(x)
        return x

tudui=Tudui()
torch.save(tudui,"tudui_method1.pth")
#陷阱——model_load.py
model=torch.load("tudui_method1.pth")
print(model)

在这里插入图片描述

需要让模型能访问到定义的class

  • 方法一:将class的定义放入model_load.py中
#陷阱
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=3)

    def forward(self, x):
        x=self.conv1(x)
        return x

model=torch.load("tudui_method1.pth")
print(model)
  • 方法二:引入model_save
from model_save import *

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值