先有网络模型的使用及修改

本文介绍了如何使用预训练的网络模型,如VGG16,在不同的数据集上进行微调。当`pretrained=True`时,模型带有在ImageNet上的权重;反之,则模型参数初始化。针对CIFAR10数据集,需要修改模型的最后一层,将其输出类别数改为10。两种方法包括直接添加新的线性层和修改现有线性层的输出特征数,以适应不同分类任务的需求。
摘要由CSDN通过智能技术生成

先有网络模型的使用

使用示例代码:

import torchvision
from torch import nn

# 加载网络

# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)

# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)

vgg16的使用有两个常用参数,分别是pretrainedprocess

  • pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
  • process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
    通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。

先有网络模型的修改(如何利用现有的网络去改动它的一个结构)

1.添加网络层

示例代码如下:

import torchvision
from torch import nn

# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构

# 1.添加网络层

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))

# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))

print("vgg16_true:\n",vgg16_true)

运行结果:
在这里插入图片描述
在这里插入图片描述

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

2.直接修改网络

示例代码如下:

import torchvision
from torch import nn

# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

# 按顺序对网络进行索引,修改最后的线性层 
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)

运行结果:
在这里插入图片描述
在这里插入图片描述

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

布兹学长

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值