tensorflow冻结模型为pb文件的各种方法

笔者最近因为工作需要将TensorFlow训练模型迁移到晟腾芯片平台上离线推理,所以需要将ckpt或者h5模型冻结成pb,再利用ATC模型转换工具将pb转为离线模型om文件,期间遇到一些问题和坑,总结一下,供大家参考。

1.Tensorflow1.x

训练好的模型Ckpt文件:

DB_resnet_v1_50_adam_model.ckpt-16801.index
DB_resnet_v1_50_adam_model.ckpt-16801.data-00000-of-00001
DB_resnet_v1_50_adam_model.ckpt-16801.meta

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值

对于这种文件,转换成pb的方法,大家应该很熟悉了,在csdn上面可以搜到很多现成的转换代码,例如:
tf1.x ckpt转pb文件
核心思想就是找到输出节点,然后冻结pb。
本文介绍一种简单的方法,可以利用tf自带的冻结脚本,这里以tf1.15为例(不同版本py文件路径有可能不能):

python3.7.5 /usr/local/python3.7.5/lib/python3.7/site-packages/tensorflow_core/python/tools/freeze_graph.py  \
--input_checkpoint=./iwslt2016_E19L2.64-29146 \
--output_graph=./transformer.pb \
--output_node_names="transformer/strided_silice_8" \
--input_meta_graph=./iwslt2016_E19L2.64-29146.meta  \
--input_binary=true

参考以上命令,根据你自己的ckpt的命名和模型输出节点修改一下,执行命令即可得到冻结的pb文件。

2.Tensorflow2.x

Keras现在是一个非常流行的工具库,包括tensorflow已经把keras合并到了自己的主代码当中了,大家可以直接tf.keras就可以直接调用其中的工具库。

如果是用Keras训练得到的模型,我们想移植到昇腾芯片上运行,那么就需要先把模型固化成TF1.x的pb格式,然后才能使用ATC模型转换工具转成om离线模型。
开始转换之前,需要确认Keras保存的模型文件(hdf5或者h5)是完整的模型文件还是仅保存权重文件,二者使用的接口不一样:

model.save(“xxx.h5”) ------保存的是完整的模型文件,即模型结构+权重文件

model.save_weights(“xxx.h5”)------保存的仅仅是权重文件,还需要调用model.to_json()来保存模型结构到一个json文件中

对于这样的h5文件转pb模型,我们可以参考:
https://github.com/amir-abdi/keras_to_tensorflow

2.1、对于保存完整模型的h5文件:
python3 keras_to_tensorflow.py 
    --input_model="path/to/keras/model.h5" 
    --output_model="path/to/save/model.pb"
2.2、对于模型权重和结构分开保存的文件:
python3 keras_to_tensorflow.py 
    --input_model="path/to/keras/model.h5" 
    --input_model_json="path/to/keras/model.json" 
    --output_model="path/to/save/model.pb"

另外还有一种情况就是模型保存的时候传入的是目录:

model.save("./checkpoint")

这时候保存的就是checkpoint文件,不同于tf1.x的checkpoint,这个目录下保存的是pb+若干子目录,不能使用tf1.x固化checkpoint的方法来生成pb文件。

对于这种ckpt文件,可以参考我的方法来生成pb:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import pdb

new_model = tf.keras.models.load_model('./Wide_ResNet/')

full_model = tf.function(lambda x: new_model(x))
full_model = full_model.get_concrete_function(x=tf.TensorSpec((None,32,32,3),'float32'))

forzen_func = convert_variables_to_constants_v2(full_model)
forzen_func.graph.as_graph_def()

layers = [op.name for op in forzen_func.graph.get_operations()]
print("-"*50)
print("Frozen model layers:")
for layer in layers:
    print(layer)

print("*"*50)
print("Frozen model input:")
print(forzen_func.inputs)
print("Frozen model output:")
print(forzen_func.outputs)

tf.io.write_graph(
    graph_or_graph_def=forzen_func.graph,
    logdir="./",
    name="WRN.pb",
    as_text=False
)

full_model = full_model.get_concrete_function(x=tf.TensorSpec((None,32,32,3),‘float32’))--------输入的shape请根据自己的网络进行修改。

3.常见问题:

1、Keras加载模型时报错找不到自定义的层?

解决方法:找到训练代码中对应的自定义层代码,在加载模型时引入自定义层即可。

例如这样:

from keras.models import load_model
model = load_model('model.h5', custom_objects={'SincConv1D': SincConv1D})
2、一些特殊的网络,比如Retinanet,单独定义了第三方库,使用load_model加载模型失败?

解决方法:使用对应第三方库的模型加载接口来加载模型,然后再转pb

keras_retinanet.models.loadmodel("xxx.h5", backbone_name='resnet50')
  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值