pretrained(bool):如果是True,说明这个模型在ImageNet数据集上已经训练好了。
progress(bool):如果为True,显示一个下载的进度条。
import torch
import torchvision
import torchvision.models as models
# train_data = torchvision.datasets.ImageNet("../data_image_net",split='train', download=True,transform=torchvision.transforms.ToTensor())
from torch import nn
vgg16_false = models.vgg16(pretrained=False) #参数是默认的参数,是初始化的
vgg16_true = models.vgg16(pretrained=True) # 要去下载在ImageNet中训练好的参数
print("ok")
print(vgg16_true)
train_data = torchvision.datasets.CIFAR10('../data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)