使用 VGG16 模型（torchvision.models.vgg16）
在 VGG16 的 classifier 末尾添加一层网络
import torchvision
from torch import nn
# VGG16 built with pretrained=False, i.e. no pretrained parameters are loaded.
vgg16_false = torchvision.models.vgg16(pretrained=False)

# VGG16 built with pretrained=True. On first use torchvision downloads the
# checkpoint into the torch hub cache; on the author's Windows machine it was
# stored at:
#   C:\Users\Administrator\.cache\torch\hub\checkpoints\vgg16-397923af.pth
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("ok")
print(vgg16_true)

# CIFAR-10 training split, kept under ./dataset and downloaded if not present;
# ToTensor() converts each image to a tensor.
# NOTE(review): train_data is not used anywhere in this snippet.
train_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Append one extra Linear layer, registered under the name 'add_linear', to the
# end of the classifier. It takes 1000 input features (presumably the output of
# the classifier's existing last layer; confirm against print(vgg16_true)) and
# produces 10 outputs, matching the 10 CIFAR-10 classes.
vgg16_true.classifier.add_module(
    name='add_linear',
    module=nn.Linear(in_features=1000, out_features=10),
)
print(vgg16_true)
实现效果：打印出的模型结构中，classifier 末尾多了一层 add_linear
修改 classifier 中的某一层（以索引为 6 的层为例）
# Replace the layer at index 6 of the untrained model's classifier (expected to
# be its final Linear layer, which takes 4096 input features) with a new Linear
# layer that produces 10 outputs.
vgg16_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg16_false)
实现效果：打印出的模型结构中，classifier[6] 已被替换为 Linear(4096, 10)