步骤:
- 配置环境(转换环境与训练环境相同可省略)
- 转换代码运行
1. 配置环境
主要指以下三个的版本需相同:
- Python
- Keras
- TensorFlow
查看环境命令请看下面例子
例如我朋友的环境是:
因此: 我负责转换时,TensorFlow
、Python
、Keras
也要与之相同。这样可以避免踩到很多不必要的坑中。
环境安装方法:
- 安装参考:在windows中安装tensorflow_gpu==1.15.0的流程。经测试不一定需要安装gpu版的tensorflow,cpu的也可以。安装方法是一样的,所以可以参考这篇博客。
2. 转换代码
因为本作者拿到的模型是在TensorFlow 2.x 版本下训练的,而转换代码的源码是在TensorFlow 1.x 版本下写的,也就是说,转换源码调用了TensorFlow 1.x 下的工具包。这造成了兼容性问题。
兼容性修复后的代码:
import tensorflow as tf
tf.keras.backend.set_learning_phase(0)
model = tf.keras.models.load_model('./model.h5') # 加载h5模型
export_path = './pb_model'
with tf.compat.v1.keras.backend.get_session() as sess:
tf.compat.v1.saved_model.simple_save(
sess,
export_path,
inputs={'input_image': model.input},
outputs={t.name: t for t in model.outputs})
代码使用:
- 将训练好的模型
.h5
文件放到与该.py
文件相同目录下,修改model = tf.keras.models.load_model('./model.h5')
中的model.h5
为自己训练好的模型的文件名及格式名 - 输出的
.pd
文件位于.py
文件所在目录中的pb_model文件夹中
代码修改部分及原因分析:
-
原因:修复兼容性问题
-
解决方法:在tensorflow2.x中有个API就是为了兼容(compat) 某些tensorflow1.x版本和某些tensorflow2.x版本,该API就是:
tf.compat.v1
:是为了兼容tensorflow1.x中的某些APItf.compat.v2
:是为了兼容tensorflow2.x中的某些API
因此只要把:
sess = tf.keras.backend.get_session()
改为如下即可:
sess = tf.compat.v1.keras.backend.get_session()
参考:
源码转自:tensorflow将.h5模型转换成.pb
附上源码:
import tensorflow as tf
tf.keras.backend.set_learning_phase(0)
model = tf.keras.models.load_model('./mmm.h5')#加载h5模型
export_path = './pb_model'
with tf.keras.backend.get_session() as sess:
tf.saved_model.simple_save(
sess,
export_path,
inputs={'input_image': model.input},
outputs={t.name: t for t in model.outputs})
在源码运行时,有可能会出现找不到引用的报错:
在 '__init__.py' 中找不到引用 'get_session'
Pycharm中提示图为:
get_session
和simple_save
标黄,表示找不到包