mport torchvision
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true) #可以看最后一层Linear(in_features=4096, out_features=1000, bias=True
#如何运用到只有十个类别的数据?
train_data = torchvision.datasets.CIFAR10('./dataset',train=True,transform=torchvision.transforms.ToTensor())
vgg16_true.add_module('add_linear',nn.Linear(1000,10)) #整个vgg16后面加
print(vgg16_true)
#如果是在中间层插入呢?比如叫classifier的module后面
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)
#不插入 直接修改呢?
# print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)
torchvision.models现有模型的基本使用
最新推荐文章于 2024-06-08 11:27:31 发布