一、使用
拉取网络模型代码:
import torchvision
from torchvision.models import VGG16_Weights
# 调用vgg16的预训练模型-此时由于参数经过了预训练参数数据就需要进行下载
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
# 调用vgg16未经过预训练的原始模型
vgg16_false = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES)
print(vgg16_true)
拉取过程:
结果:
二、修改
2.1添加线性层
vgg16_true.add_module("add_linear",torch.nn.Linear(1000,10))
结果:
2.2修改
import torch.nn
import torchvision
from torchvision.models import VGG16_Weights
# 调用vgg16的预训练模型-此时由于参数经过了预训练参数数据就需要进行下载
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
# 调用vgg16未经过预训练的原始模型
vgg16_false = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES)
# 修改分类器里面的具体层-7层
vgg16_true.classifier[6] = torch.nn.Linear(1000,10)
# 修改序列化里面的具体层-31层
vgg16_true.features[30] = torch.nn.Linear(1000,10)
print(vgg16_true)
结果:
2.3删除
代码:
# 删除分类器中的第6层
vgg16_true.classifier[5] = nn.Sequential()
结果: