#False,下载的是网络模型,默认参数
#vgg16_false = torchvision.models.vgg16(pretrained = False)
#True,下载的是网络模型,并且在数据集上面训练好的参数。
#vgg16_true = torchvision.models.vgg16(pretrained = True)
#学习在现有的网络进行修改。vgg16是将数据分为1000类。而数据集CIFAR10只有十类。
#1.给vgg16模in_feature=1000,out_feature = 10;
#2.直接修改,将最后一层改为out_feature = 10;
import torchvision
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16_true = torchvision.models.vgg16(pretrained = True)
print(vgg16_true)
train_data = torchvision.datasets.CIFAR10("a",train=True,transform=torchvision.transforms.ToTensor(),
download = True)
#在vgg16的classifier下加一层模型,名叫add_linear,module名,in_feature=1000,out_feature=10
vgg16_true.classifier.add_module('add_lnear',nn.Linear(1000,10))
print(vgg16_true)
#修改最后一行结构为out_feature=10
vgg16_false.classfier[6] = nn.Linear(4096,10)
print(vgg16_false)