Keras模型
- 保存整个模型为 h5 文件,内容包括
模型的结构
模型的权值
模型的配置,即我们通过compile编译模型的一些信息,如优化器,损失函数等
优化器的状态信息,我们可以接着之前的训练继续训练
保存整个模型
model.save('path_to_my_model.h5')
keras.models.load_model('path_to_my_model.h5')
加载整个模型
保存权重
model.save_weights('path_to_my_weights.h5')
new_model.load_weights('path_to_my_weights.h5')
Pytorch
保存的文件格式为pkl pth等,其实就是把权重优化器等信息序列化了。
保存权重
torch.save(model.state_dict(), PATH)
OR
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
},path)
加载模型
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()