pytorch对模型的操作
导入
这里以vgg11的classifier层为例
导入模型:
vgg11 = torchvision.models.vgg11(pretrained=False)
print(vgg11)
其中classifier的输出如下
修改
主要的修改方式:
# 在某层中添加层
vgg11.classifier.add_module('new_linear', nn.Linear(1000, 10))
# 修改某层
vgg11.classifier[6] = nn.Linear(4096, 10)
print(vgg11)
修改后:
保存与读取
# 模型的保存与读取
# 方式1,保存模型和参数
torch.save(vgg11, 'vgg11_method1.pth')
# 在读取时需要保证原模型已经引入
model = torch.load('vgg11_method1.pth')
# 方式2,只保存模型参数,一个字典形式(官方推荐)
torch.save(vgg11.state_dict(), 'vgg11_method2.pth')
vgg11_new = torchvision.models.vgg11(pretrained=False)
vgg11_new.load_state_dict(torch.load('vgg11_method2.pth'))