对网络进行修改,首先需要打开网络,查看网络有哪些结构。然后在想要修改的地方进行修改或添加。
如上修改最后一层线性层,改为输入4096,输出10.
vgg16_false.classifier[6] =nn.Linear(4096, 10)
如果在Vgg中添加一层,
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))#在VGG中添加功能,
import torchvision.datasets
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))#在VGG中添加功能,如果想在Vgg下子程序加,vgg16_true.classifier.add_module()
print(vgg16_true)
print(vgg16_false)
vgg16_false.classifier[6] =nn.Linear(4096, 10)
print(vgg16_false)