18.1 以VGG16为例使用并修改现有model
import torch
import torchvision.transforms
from torch import nn
# train_data = torchvision.datasets.ImageNet("../data_image_net", split='train', download=True,
# transform=torchvision.transforms.ToTensor()) # 该数据集过大,放弃下载
vgg16_true = torchvision.models.vgg16(weights=None)
# 加载本地预训练权重文件,并把权重存储在字典中
pretrained_dict = torch.load(r'C:\Users\Avlon\.cache\torch\hub\checkpoints\vgg16-397923af.pth')
vgg16_true.load_state_dict(pretrained_dict) # 将下载的预训练权重加载到模型中
vgg16_default = torchvision.models.vgg16() # 返回一个随机初始化的网络模型
# print('ok')
train_data = torchvision.datasets.CIFAR10('../data', train=True, transform=torchvision.transforms.ToTensor(),
download=True)
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10)) # 在 classifier 末尾增加线性层;将输出大小变成10
print(vgg16_true)
# print(vgg16_default)
vgg16_default.classifier[6] = nn.Linear(4096, 10) # 将 classifier 第7层修改为线性层,输出大小变成10
print(vgg16_default)
18.2 小结
基本操作很简单,不过因为Pytorch版本更新换代的关系,是否使用vgg16预训练模型——参数的设置已经与课程中不同。这里本人把问题想复杂了(8-11行),如官方文档所写
pretrained=False 改为空白(啥都不填),pretrained=True 改为 weights=‘DEFAULT’
# torchvision.models.vgg16() 创建一个无预训练权重的随机初始化模型;
# torchvision.models.vgg16(weights=None) 创建的模型无默认的权重文件,需手动加载;
# torchvision.models.vgg16(weights=‘DEFAULT’) 创建的模型默认将预训练权重文件下载并加载到模型中