序列化与反序列化
模型的保存与加载也称序列化与反序列化
模型在内存中是以对象的形式存储的,而在硬盘中是以二进制序列保存的
序列化:是指将内存当中的某一个对象以二进制序列的形式存储到硬盘中,就可以长久的存储。
反序列化:将硬盘中的二进制数反序列化的放到内存中,得到对象,这样就可以使用模型了。
对应pytorch中的函数:
- torch.save
主要参数:
- obj:对象(模型、张量、parameters、dict 等等)
- f:输出路径(指定一个硬盘中的路径去保存)
模型保存有两种方法:
法1:保存整个Module
torch.save(net, path)
法2:保存模型参数
state_dict = net.state_dict()
torch.save(state_dict, path)
比如:
net = LeNet2(classes=2019)
# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict(