自带模型的增、改
import torchvision
from torch import nn
# 加载vgg16网络模型,pretrained 是否使用优质网络的参数,并不是权重参数
vgg16_f = torchvision.models.vgg16(pretrained=False)
# 加载vgg16网络模型,pretrained 是否使用优质参数
vgg16_t = torchvision.models.vgg16(pretrained=True)
print(vgg16_t)
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor)
# 增加一层 网络 当前为全连接层
vgg16_f.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_f)
# 改层
vgg16_t.classifier[6] = nn.Linear(4096, 10)
print(vgg16_t)