上一篇笔记写了torch.save和torch.load来存储和读取训练好的model,这一篇是关于另一种saving和loading model的方法—用参数字典而不是整个训练好的model来加载model。
torch.nn.Module.load_state_dict
- 需要理解的定义
state_dict: 就是一个简单的Python 字典对象(dictionary object),用来存储参数(比如weights、biases),字典中存储model的layers和它对应的权重张量相对应。
(note: 部分layer如卷积层、线性层有参数所以存储在state_dict里,而有的layer如pooling layer没有参数可学习,所以在state_dict没有对应layer及参数)
# 举例
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(<