本文转自:https://zhuanlan.zhihu.com/p/89442276,作者解释得很棒,生怕作者删了文章,故copy过来,在此感谢作者!
模型保存
在 Pytorch 中一种模型保存和加载的方式如下:
# save
torch.save(model.state_dict(), PATH)
# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
可以看到模型保存的是 model.state_dict()
的返回对象。 model.state_dict()
的返回对象是一个 OrderDict
,它以键值对的形式包含模型中需要保存下来的参数,例如:
class MyModule(nn.Module):
def __init__(self, input_size, output_size):
super(MyModule, self).__init__()
self.lin = nn.Linear(input_size, output_size)
def forward(self, x):
return self.lin(x)
module = MyModule(4, 2)
print(module.state_dict())
输出结果:
模型中的参数就是线性层的 weight 和 bias.
Parameter 和 buffer
If you have parameters in your model, which should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.Buffers won’t be returned in model.parameters(), so that the optimizer won’t have a change to update them.
模型中需要保存下来的参数包括两种:
- 一种是反向传播需要被optimizer更新的,称之为 parameter
- 一种是反向传播不需要被optimizer更新,称之为 buffer
第一种参数我们可以通过 model.parameters()
返回;第二种参数我们可以通过 model.buffers()
返回。因为我们的模型保存的是 state_dict
返回的 OrderDict
,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict
。
那么现在的问题是这两种参数如何创建呢,创建好了如何保存到OrderDict
呢?
第一种参数有两种方式:
- 我们可以直接将模型的成员变量(http://self.xxx) 通过
nn.Parameter()
创建,会自动注册到parameters中,可以通过model.parameters() 返回,并且这样创建的参数会自动保存到OrderDict中去; - 通过
nn.Parameter()
创建普通P