TensorFlow2 保存训练完成的网络,在其它程序中使用。
今天碰到个新问题:
TF2的网络在训练完成后,我需要保存下来在其它的测试数据上进行测试。
在新的文件中,通过装载保存的网络,获取训练完成的网络的所有信息。
然后将新数据一一的进行预测,最终输出预测精度。
官方解决办法
还是先给出相关的官方文档,
官方文档:https://tensorflow.google.cn/tutorials/distribute/save_and_load?hl=zh-cn
官方核心说法
如果已经有一个简单的模型可供使用,那么有两组可用的 API可以用来保存和装载网络:
如果你定义网络时继承的是tf.keras.Model,最好用高级方法。
- 高级 model.save 和 tf.keras.models.load_model
- 低级 tf.saved_model.save 和 tf.saved_model.load
高级方法的保存和读取
高级方法使用很简单,model就是你定义好的网络,最好能compile编译一下,否则在装载时会有告警。
编译后的网络,直接用
model.save(savepath)
保存就好,这里的savepath是你自己的保存路径字符串。
我的保存路径在当前目录下的一个叫