1. Loading official models
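The official pretrained models live in torchvision.models; the available models are listed on the PyTorch website. For example, torchvision.models.vgg16(pretrained=True) builds VGG16 and downloads its ImageNet-trained weights. In torchvision 0.13 and later, pretrained is deprecated in favor of weights=torchvision.models.VGG16_Weights.DEFAULT.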
2. Appending modules to a model
(1) add_module appends the new module after all existing top-level modules
import torchvision
from torch import nn
vgg16_true = torchvision.models.vgg16(pretrained=True)  # downloads the ImageNet weights on first use
vgg16_true.add_module("add_linear", nn.Linear(1000, 10))  # registered as the last top-level child
Comparison (only the tail of print(vgg16_true) is shown). Before appending:
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
After appending:
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
(add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
(2) Appending a module inside a specific Sequential
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
After appending (for the state before, see the first printout above):
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
(add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
)
3. Modifying a model
vgg16_true.classifier[6] = nn.Linear(4096, 10)
Layers inside a Sequential are accessed by index, like array elements. The new layer's in_features must equal 4096, the width of the features produced by layer (3); any other value makes the forward pass fail with a shape-mismatch error.
Before the modification:
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
After the modification:
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=10, bias=True)
)
)
4. Saving models
Training keeps refining a model's parameters, so once they are tuned the trained model has to be saved to disk to be reused later.
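(1) Saving only the parameters (the state_dict)
import torch
torch.save(model.state_dict(), 'parameter.pkl')  # what to save (the parameters), where to save it
model = TheModelClass(...)  # the same architecture must be rebuilt in code first
model.load_state_dict(torch.load('parameter.pkl'))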
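A self-contained sketch of the state_dict round trip. TinyNet is a hypothetical stand-in for TheModelClass, and the file is written to a temporary directory rather than a fixed path.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "parameter.pkl")

# Save only the learned tensors (an ordered dict of name -> tensor).
torch.save(model.state_dict(), path)

# Rebuild the architecture in code, then load the saved tensors into it.
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```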
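(2) Saving the whole model
import torch
torch.save(model, 'model.pkl')  # pickles the entire nn.Module object
model = torch.load('model.pkl', weights_only=False)  # needed since PyTorch 2.6; only load files you trust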
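A sketch of saving and reloading a whole model, assuming PyTorch 1.13 or later (where torch.load has the weights_only argument). It uses a plain nn.Sequential so the pickled classes are always importable; with a custom nn.Module subclass, its class definition must be importable at load time, because pickle stores a reference to the class rather than its code.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
path = os.path.join(tempfile.mkdtemp(), "model.pkl")

# Pickles the whole module: architecture and parameters together.
torch.save(model, path)

# PyTorch 2.6 changed the default to weights_only=True, which refuses
# arbitrary pickled objects such as a full nn.Module.
loaded = torch.load(path, weights_only=False)

x = torch.randn(3, 4)
print(type(loaded).__name__)  # Sequential
print(torch.equal(model(x), loaded(x)))  # True
```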