现有网络模型的使用和修改

1、参数pretrained 为false和true时的区别

  • vgg16_false = torchvision.models.vgg16(pretrained = False)
    #False,下载的是网络模型,默认参数

  • vgg16_true = torchvision.models.vgg16(pretrained = True)
    #True,下载的是网络模型,并且在数据集上面训练好的参数。

pretrained=False时,只是加载网络模型,把神经网络的代码加载了进来,其中的参数都是默认的参数,不需要下载。

pretrained=True时,它就要去从网络中下载,比如说卷积层对应的参数时多少,池化层对应的参数时多少等。这些参数都是在 ImageNet 数据集中训练好的。

2、现有网络模型的使用与修改

如何利用现有的网络,改变网络的框架来符合我们的需求?
我们之前使用的数据集CIFAR10只分为10个类别
如果我们加载了vgg16模型,我们如何应用这个网络模型呢?

有两个方法:
(1)将 (6): Linear(in_features=4096, out_features=1000, bias=True)改为Linear(in_features=4096, out_features=10, bias=True)
(2)再加一个线性层 Linear(in_features=1000, out_features=10, bias=True)

import torchvision



# 利用现有的网络改动
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)

train_data = torchvision.datasets.CIFAR10('./dataset_ts', train=True, transform=torchvision.transforms.ToTensor(),download=True)

# 在vgg16的classifier下加一层模型,名叫add_linear,module名,in_feature=1000,out_feature=10
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

print(vgg16_false)
# 修改最后一行结构为out_feature=10
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

参考:

土堆视频
https://blog.csdn.net/m0_51816252/article/details/125125715
https://blog.csdn.net/Crystalxxtt/article/details/124933634

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值