Keras训练的h5模型转Tensorflow pb模型OpenCV可调用

环境:Keras 2.2.4

           Tensorflow-gpu 1.12

这里的h5模型是由keras训练保存的,注意不是tf.keras。因为如果是tf.keras训练生成的模型,那可能没这么多坑了。

也就是用的

from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input, Flatten, Conv2D,Softmax
   
   

这样构建的网络。

注意保存模型时要使用model.save(),如果使用model.save_weights()仅保存了权重,要恢复还要知道网络结构。

这里是保存的h5模型,注意到有BatchNormalization这样的操作,多半的坑都是因为它。

坑:

加载h5模型时使用的是


   
   
  1. from keras.models import load_model
  2. my_model = load_model( 'my_model.h5')

然后按照后面的流程也可以成功得到pb模型文件,但是opencv不能成功加载模型,看一下这个模型

和BatchNormalization有关的节点明显改变了。

看一下两个模型在BatchNormalization这个节点的不同。

如果用opencv的dnn来加载这个pb文件,会报错

cv2.error: OpenCV(4.1.1) C:\projects\opencv-python\opencv\modules\dnn\src\tensorflow\tf_importer.cpp:582: error: (-2:Unspecified error) Input [batch_normalization_1/ones_like] for node [batch_normalization_1/FusedBatchNorm_1] not found in function 'cv::dnn::dnn4_v20190621::`anonymous-namespace'::TFImporter::getConstBlob'
   
   

从这个log也可以推断是和BatchNormalization有关。

排坑:

 一定要用tensorflow.keras加载h5文件。


   
   
  1. from tensorflow.keras as keras
  2. my_model = keras.models.load_model( 'my_model.h5')

因为如果用原生keras加载模型会使用keras.engine.sequential.Sequential,这可能会导致一些object无法直接转换成tensorflow pb文件。

可以看到用tensorflow keras保存的模型格式已经改变了。

所以完整的流程应该是:

  1. keras训练,保存h5文件
  2. tensorflow keras加载h5文件
  3. 转换成tensorflow pb文件
  4. opencv dnn调用pb文件

   
   
  1. import tensorflow.keras as keras
  2. import tensorflow as tf
  3. import os
  4. #这个函数参考自网上
  5. def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
  6. """
  7. Freezes the state of a session into a pruned computation graph.
  8. Creates a new computation graph where variable nodes are replaced by
  9. constants taking their current value in the session. The new graph will be
  10. pruned so subgraphs that are not necessary to compute the requested
  11. outputs are removed.
  12. @param session The TensorFlow session to be frozen.
  13. @param keep_var_names A list of variable names that should not be frozen,
  14. or None to freeze all the variables in the graph.
  15. @param output_names Names of the relevant graph outputs.
  16. @param clear_devices Remove the device directives from the graph for better portability.
  17. @return The frozen graph definition.
  18. """
  19. graph = session.graph
  20. with graph.as_default():
  21. freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  22. output_names = output_names or []
  23. output_names += [v.op.name for v in tf.global_variables()]
  24. input_graph_def = graph.as_graph_def()
  25. if clear_devices:
  26. for node in input_graph_def.node:
  27. node.device = ''
  28. frozen_graph = tf.graph_util.convert_variables_to_constants(
  29. session, input_graph_def, output_names, freeze_var_names)
  30. return frozen_graph
  31. if __name__ == '__main__':
  32. input_path = 'D:\\training'
  33. #keras训练保存的h5文件
  34. input_file = 'my_model.h5'
  35. weight_file_path = os.path.join(input_path, input_file)
  36. output_graph_name = weight_file[: -3] + '.pb'
  37. # 加载模型
  38. keras.backend.set_learning_phase( 0)
  39. h5_model = keras.models.load_model(weight_file_path)
  40. frozen_graph = freeze_session(keras.backend.get_session(), output_names=[out.op.name for out in h5_model.outputs])
  41. tf.train.write_graph(frozen_graph, input_path, output_graph_name, as_text= False)
  42. print( 'Finished')
  43. import cv2
  44. model = cv2.dnn.readNetFromTensorflow( "D:\\training\\my_model.pb")
  45. print( 'Load')

通过这种方式就可以正确转换成pb文件了,并且这个pb文件opencv也是可以成功调用的了。

但是转换之后模型变成了

似乎是有一些在推理中不必要的节点。

可以通过下面的方式对pb模型进行进一步优化

python -m tensorflow.python.tools.optimize_for_inference --input my_model.pb --output my_model_opt.pb --input_names=input_1 --output_names=dense_1/Softmax
   
   

优化之后模型的体积小了不少,那些多出来的节点也没有了,另外BatchNormalization这样在推理中不需要的节点也没有了。

 

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值