【学习笔记】现有网络模型的使用及修改

主要内容:

        一. 如何加载pytorch提供的网络模型?

        二.如何修改网络模型?

打开pytorch官网:

PyTorch icon-default.png?t=N7T8https://pytorch.org/选择torchvision,版本选择0.9.0,打开torchvision.models,下面以VGG16模型(在ImageNet数据集中训练的)为例,

import torchvision


#上面下载量一百多个G, 太大了
#train_data=torchvision.datasets.ImageNet('./data_image_net',split='train',download=True,transform=torchvision.transforms.ToTensor())

vgg16_false = torchvision.models.vgg16(pretrained=False,progress=True)
vgg16_true=torchvision.models.vgg16(pretrained=True,progress=True)
"""
pretrained=False: 表示下载的网络模型(VGG16)参数未在任何数据集中进行训练,也就是说这些参数都是初始化参数,相当于前面写的
cifar10 model structure 网络架构 
progress=True: 显示下载到标准程序的进度条
pretrained=True: 表示下载的网络模型(VGG16)其中的参数已经在数据集(ImageNet)中训练好了,取得了不错的效果
"""
#print训练好的网络模型(架构)
print(vgg16_true)

输出结果:

 

 查看ImageNet数据集:

ImageNeticon-default.png?t=N7T8https://image-net.org/challenges/LSVRC/2012/index.php

如何利用现有网络通过改动它的结构,而避免写vgg16结构? 因为很多框架将vgg16当作前置网络结构(用来提取一些特殊的特征),然后在vgg16后面加一些结构,实现特殊的功能

下面有两种方法将vgg16模型通过添加一个线性层linear,使网络模型满足CIFAR10的网络模型结构

首先,看一下vgg16与CIFAR10的模型结构:

 

方法一:在上述代码下面添加下列代码并运行:

train_data=torchvision.datasets.CIFAR10('./datasets',train=True,transform=torchvision.transforms.ToTensor(),
                                        download=True)
#如何利用现有网络通过改动它的结构,避免写vgg16, 很多框架将vgg16当作前置网络结构(用来提取一些特殊的特征),然后在vgg16后面加一些结构,实现特殊的功能
#在vgg16中添加一层线性层
vgg16_true.add_module('add_linear',nn.Linear(in_features=1000, out_features=10))  #in_features=1000这里为什么是1000? 由vgg16最后线性层的输出决定
print(vgg16_true)

 运行结果为:

如何把add_linear加入到classifier中呢?

将上述代码vgg16.add_module(...)改为vgg16.classifier.add_module(...)即可,运行结果如下:

 方法二:通过改vgg16模型的输出,将输出1000类别改为10,具体如何用代码实现呢?

在vgg16_false上修改,首先print(vgg16_false),如下:

在上面代码下面添加下面几行代码:

print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(in_features=4096,out_features=10)
print(vgg16_false)

 运行结果如下:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值