小土堆-pytorch框架学习-P25-现有网络模型修改&使用

pytorch现提供的网络模型该怎么修改?该怎么使用?

现有网络模型的使用及修改

分类模型-网站:https://pytorch.org/vision/stable/models/vgg.html VGG16分类。最常用的是VGG16 VGG19。这是网络模型。

如果pretrained 为true,则表示它已经被预训练好,可以取得不错的效果。否则表示没有在任何数据集上进行训练。progress为true,则表示显示进度条。否则不显示。

ImageNet数据集很大。而且不能通过代码下载,必须手动下载。

所以就得使用之前的数据集或者只是用网络结构。

只使用网络结构👇

import torchvision.datasets
from torch import nn

# train_data = torchvision.datasets.ImageNet(root = "./dataset/imagenet" , split="train" ,
#                                            transform=torchvision.transforms.ToTensor(),download=True)

vgg16_false = torchvision.models.vgg16(pretrained=False)#当为false时,只是加载网络模型,参数是默认参数,不需要下载,相当于之前写的网络架构
vgg16_true = torchvision.models.vgg16(pretrained=True)#当为true时,下载对应参数是多少;这些模型的参数就是在数据集上训练好的
print("ok")
print("vgg16_false : " , vgg16_false)

print("vgg16_true : " , vgg16_true)

运行结果👇

VGG16_false网络架构👇

image-20230706180022737

image-20230706180031689

VGG16_true网络架构👇

image-20230706180100954

image-20230706180112764

像之前数据集CIFAR10只有10类,但此时VGG16分类模型有分1000类,作者提供了两种思路来改进网络模型。

  1. 直接在最后一层线性层,输出特征修改为10。
  2. 在最后一层线性层后再增加一层,使得in_features = 1000 , out_features = 10。

添加

怎么添加上述的第二种方法的线性层?一行代码搞定👇直接在VGG16_true整个模型后加

vgg16_true.add_module('add_linear', nn.Linear(1000,10))
#可以看到网络模型最后多了一层名为add_linear的线性层

image-20230706180451481

或者若想要在classfifier子结构中添加,使用下面的代码搞定👇

vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))

可以看到最后的线性层添加到了classifier中。

image-20230706180519386

修改

现在是另一种方法-修改,因为VGG16_true已经添加,所以现在用VGG16_false进行修改。

一行代码👇

vgg16_false.classifier[6] = nn.Linear(4096,10)
#看网络结构图可以知道最后分类模块角标是6,所以方括号内用6。

可以看到VGG16_false修改成功

image-20230706180748055


为什么要用VGG16模型?

大多数框架是把VGG16当作前置的网络结构,来提取一些特殊特征,或之后加一些网络结构实现功能。所以为什么用它是因为使用它的框架数量多。

代码👇

import torchvision.datasets
from torch import nn

# train_data = torchvision.datasets.ImageNet(root = "./dataset/imagenet" , split="train" ,
#                                            transform=torchvision.transforms.ToTensor(),download=True)

vgg16_false = torchvision.models.vgg16(pretrained=False)#当为false时,只是加载网络模型,参数是默认参数,不需要下载,相当于之前写的网络架构
vgg16_true = torchvision.models.vgg16(pretrained=True)#当为true时,下载对应参数是多少;这些模型的参数就是在数据集上训练好的
print("ok")
# print("vgg16_false : " , vgg16_false)
print('-'*111)
print("vgg16_true : " , vgg16_true)
#添加
#方法1
# vgg16_true.add_module(name = "add_linear",  module = nn.Linear(1000,10))

#方法2
# vgg16_true.classifier.add_module(name = "add_linear" , module = nn.Linear(1000,10))

# print(F"vgg16_true:{vgg16_true}")

#修改
vgg16_false.classifier[6]= nn.Linear(in_features = 4096 , out_features = 10)

print(F"vgg16_false = {vgg16_false}")

vgg16_false.classifier[6]= nn.Linear(in_features = 4096 , out_features = 10)

print(F"vgg16_false = {vgg16_false}")
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

HelpFireCode

随缘惜缘不攀缘。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值