- 以vgg16为例,先加载模型:参数pretrain为False表示不需要去下载参数,使用默认参数;为True表示需要去下载参数,已经在数据集上训练好的,效果会更好。
vgg16_false=torchvision.models.vgg16(pretrained=False) #False表示不需要去下载参数,使用默认参数
vgg16_true=torchvision.models.vgg16(pretrained=True) #True表示需要去下载参数,已经在数据集上训练好的
print(vgg16_true) #查看模型结构
- 如何添加模型结构?
1.在最后一层添加:
vgg16_true.add_module("add_linear",nn.Linear(1000,10)) #在最后添加
print(vgg16_true)
2.在模型内部添加:
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))#在中间添加
print(vgg16_true)
- 如何修改模型结构?
print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(4096,1000)
print(vgg16_false)