import torchvision
from torch import nn
vgg_false = torchvision.models.vgg16(pretrained=False)
# 参数使用在别的数据集上训练好的初始化
vgg_true = torchvision.models.vgg16(pretrained=True)
# 在现有的网络结构进行一些修改,使它适用于自己的网络结构
train_data = torchvision.datasets.CIFAR10("./dataset", train=True, transform=torchvision.transforms.ToTensor(),
download=True)
# 在整个模型加
# vgg_true.add_module('add_linear', nn.Linear(1000, 10))
# print(vgg_true)
#在classifier中加
vgg_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg_true)
# 对模型进行修改
vgg_false.classifier[6] = nn.Linear(4096, 10)
print(vgg_false)
参考地址:https://www.bilibili.com/video/BV1hE411t7RN?p=25