Pytorch模型添加、删除、修改网络层
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=True) # 加载预训练网络模型
## 添加层
vgg16.add_module("add_linear_end", nn.Linear(1000, 10)) # 在整个module外面增加一个Linear (1)
vgg16.classifier.add_module("add_linear", nn.Linear(1000, 10)) # (2)
## 修改层
# 方式一:知道输入特征维度
vgg16.classifier[6] = nn.Linear(4096, 10) # (3)
# ⭐方式二:不用自己查,直接代码获得输入特征维度
num_fc = vgg16.classifier[6].in_features #读取输入特征的维度
vgg16.classifier[6] = nn.Linear(num_fc,2) #修改最后一层的输出维度,即分类数 (4)
## 删除层
del vgg16.classifier[6] # (5)
del vgg16.classifier[6] # (6)
输出截图如下
(1)
(2)
(3)修改了分类器输出特征数
(4)实现的功能和 3 一样,但是这种做法更具有普适性
(5)
(6)