先加载vgg16网络模型
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=True)
print(vgg16)
网络结构,可以看出,此网络是个分类模型,有1000个类别,而之前所用到的CIFAR10只有10个类别,怎么运用这个模型呢?🧐很简单,在最后加个线性层,将1000->10
💛 加载数据集
dataset = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(), download=True)
💛 添加线性层
vgg16.add_module('add_liner', nn.Linear(1000, 10))
💬或者也可以直接对模型进行修改,无需添加
vgg16.classifier[6] = nn.Linear(4096, 10)
print(vgg16)