pytorch现提供的网络模型该怎么修改?该怎么使用?
现有网络模型的使用及修改
分类模型-网站:https://pytorch.org/vision/stable/models/vgg.html VGG16分类。最常用的是VGG16 VGG19。这是网络模型。
如果pretrained 为true,则表示它已经被预训练好,可以取得不错的效果。否则表示没有在任何数据集上进行训练。progress为true,则表示显示进度条。否则不显示。
ImageNet数据集很大。而且不能通过代码下载,必须手动下载。
所以就得使用之前的数据集或者只是用网络结构。
只使用网络结构👇
import torchvision.datasets
from torch import nn
# train_data = torchvision.datasets.ImageNet(root = "./dataset/imagenet" , split="train" ,
# transform=torchvision.transforms.ToTensor(),download=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)#当为false时,只是加载网络模型,参数是默认参数,不需要下载,相当于之前写的网络架构
vgg16_true = torchvision.models.vgg16(pretrained=True)#当为true时,下载对应参数是多少;这些模型的参数就是在数据集上训练好的
print("ok")
print("vgg16_false : " , vgg16_false)
print("vgg16_true : " , vgg16_true)
运行结果👇
VGG16_false网络架构👇
VGG16_true网络架构👇
像之前数据集CIFAR10只有10类,但此时VGG16分类模型有分1000类,作者提供了两种思路来改进网络模型。
- 直接在最后一层线性层,输出特征修改为10。
- 在最后一层线性层后再增加一层,使得in_features = 1000 , out_features = 10。
添加
怎么添加上述的第二种方法的线性层?一行代码搞定👇直接在VGG16_true整个模型后加
vgg16_true.add_module('add_linear', nn.Linear(1000,10))
#可以看到网络模型最后多了一层名为add_linear的线性层
或者若想要在classfifier子结构中添加,使用下面的代码搞定👇
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
可以看到最后的线性层添加到了classifier中。
修改
现在是另一种方法-修改,因为VGG16_true已经添加,所以现在用VGG16_false进行修改。
一行代码👇
vgg16_false.classifier[6] = nn.Linear(4096,10)
#看网络结构图可以知道最后分类模块角标是6,所以方括号内用6。
可以看到VGG16_false修改成功
为什么要用VGG16模型?
大多数框架是把VGG16当作前置的网络结构,来提取一些特殊特征,或之后加一些网络结构实现功能。所以为什么用它是因为使用它的框架数量多。
代码👇
import torchvision.datasets
from torch import nn
# train_data = torchvision.datasets.ImageNet(root = "./dataset/imagenet" , split="train" ,
# transform=torchvision.transforms.ToTensor(),download=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)#当为false时,只是加载网络模型,参数是默认参数,不需要下载,相当于之前写的网络架构
vgg16_true = torchvision.models.vgg16(pretrained=True)#当为true时,下载对应参数是多少;这些模型的参数就是在数据集上训练好的
print("ok")
# print("vgg16_false : " , vgg16_false)
print('-'*111)
print("vgg16_true : " , vgg16_true)
#添加
#方法1
# vgg16_true.add_module(name = "add_linear", module = nn.Linear(1000,10))
#方法2
# vgg16_true.classifier.add_module(name = "add_linear" , module = nn.Linear(1000,10))
# print(F"vgg16_true:{vgg16_true}")
#修改
vgg16_false.classifier[6]= nn.Linear(in_features = 4096 , out_features = 10)
print(F"vgg16_false = {vgg16_false}")
vgg16_false.classifier[6]= nn.Linear(in_features = 4096 , out_features = 10)
print(F"vgg16_false = {vgg16_false}")