【笔记】TF 模型不同格式之间的转换:ckpt 转 pb、mlmodel;mlmodel 转 mlmodelc;pb 转 mnn

该博客介绍了如何将 TensorFlow 的 ckpt 模型冻结为 .pb 文件,再进一步转换为 iOS 可用的 .mlmodel 并编译为 .mlmodelc,同时介绍了从 .pb 转换到 MNN 格式的过程。主要涉及模型冻结、Core ML 模型输入输出精度调整(DOUBLE 改为 FLOAT32)以及 MNN 模型转换工具的使用。
摘要由CSDN通过智能技术生成

注:

 

正文:

1. ckpt 转 pb 和 mlmodel

# csdn -牧野- 2020-7-18
# https://blog.csdn.net/dcrmg/article/details/107213367
import tensorflow as tf
from tensorflow import graph_util
import os
import tfcoreml
import network
 
def freeze_graph(input_checkpoint, output_graph, output_node_names='DepthToSpace',
                 input_shape=(1, 512, 512, 4), input_name='xx', network_fn=None):
    """Freeze a TF1 checkpoint into a self-contained ``.pb`` GraphDef.

    Rebuilds the inference graph (float32 placeholder -> network) and
    restores the trained variables from ``input_checkpoint``. It then turns
    the variables into constants with
    ``graph_util.convert_variables_to_constants`` and writes the serialized
    GraphDef to ``output_graph``. Afterwards it prints every op of the
    rebuilt graph (name and output tensors), which helps find the
    input/output node names.

    Args:
        input_checkpoint: checkpoint path/prefix passed to
            ``tf.train.Saver.restore`` (e.g. ``./checkpoint/model_xx.ckpt``).
        output_graph: path of the ``.pb`` file to write.
        output_node_names: comma-separated op names of the graph outputs;
            surrounding whitespace around each name is ignored.
        input_shape: static shape of the float32 input placeholder. The
            default ``(1, 512, 512, 4)`` is the original hard-coded shape.
        input_name: op name of the input placeholder (its tensor is
            ``<input_name>:0``). The default ``'xx'`` is the original name.
        network_fn: callable that builds the network on the input tensor.
            Defaults to ``network_XX``.
    """
    if network_fn is None:
        # NOTE(review): ``network_XX`` is not defined anywhere in this script
        # (only the ``network`` module is imported). Presumably it should be a
        # builder from ``network`` -- TODO confirm, or pass ``network_fn``.
        network_fn = network_XX

    # Build in a fresh graph so that calling this more than once does not
    # collide with (or get auto-renamed by) ops left in the default graph.
    with tf.Graph().as_default() as graph:
        in_image = tf.placeholder(tf.float32, list(input_shape), input_name)
        network_fn(in_image)  # outputs are selected by name further below
        # Snapshot the GraphDef before the Saver/initializer ops are added.
        input_graph_def = graph.as_graph_def()
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, input_checkpoint)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=[name.strip()
                                   for name in output_node_names.split(",")
                                   if name.strip()])
            with tf.gfile.GFile(output_graph, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node))
            # Debug aid: list every op to help locate input/output node names.
            for op in graph.get_operations():
                print(op.name, op.values())
 
def convert_CoreMLModel_DOUBLE2FLOAT(mlmodel_save_path, output_model_path=None):
    """Declare every MultiArray input/output of a Core ML model as FLOAT32.

    Loads the model spec from ``mlmodel_save_path``. For each input and
    output feature whose type is a MultiArray, it sets the declared
    ``dataType`` to ``FLOAT32``; other feature types are left untouched.
    Then it saves the spec. Presumably needed because the converter
    declares these arrays as DOUBLE -- TODO confirm for the tfcoreml
    version in use.

    Args:
        mlmodel_save_path: path of the ``.mlmodel`` file to read.
        output_model_path: where to save the patched model. Defaults to
            ``mlmodel_save_path``, i.e. overwrite in place (the original
            behavior).
    """
    # The visible file-level imports bring in only ``tfcoreml``, not
    # ``coremltools``. ``import coremltools.proto...as _ft`` binds only
    # ``_ft``, so ``coremltools.utils`` would be a NameError without this.
    import coremltools
    import coremltools.proto.FeatureTypes_pb2 as _ft

    if output_model_path is None:
        output_model_path = mlmodel_save_path

    spec = coremltools.utils.load_spec(mlmodel_save_path)
    for features in (spec.description.input, spec.description.output):
        for feature in features:
            if feature.type.HasField('multiArrayType'):
                feature.type.multiArrayType.dataType = _ft.ArrayFeatureType.FLOAT32
    coremltools.utils.save_spec(spec, output_model_path)
 
 
if __name__ == '__main__':
    checkpoint_folder = './checkpoint'
    checkpoint_name = 'model_xx.ckpt'
    pb_name = 'xx.pb'
    mlmodel_name = 'xx.mlmodel'

    # Node names are placeholders ('xx'). Look the real ones up with Netron,
    # or in the op list printed by freeze_graph. The input node name must
    # match the placeholder created inside freeze_graph. TensorFlow tensor
    # names are "<node name>:0".
    input_node_name = 'xx'
    output_node_name = 'xx'
    input_shape = [1, 512, 512, 4]  # [batch, height, width, channel]

    pb_dir = os.path.join(checkpoint_folder, 'pb')
    mlmodel_dir = os.path.join(checkpoint_folder, 'mlmodel')
    pb_save_path = os.path.join(pb_dir, pb_name)
    mlmodel_save_path = os.path.join(mlmodel_dir, mlmodel_name)

    # makedirs (unlike mkdir) also creates a missing checkpoint_folder.
    for out_dir in (pb_dir, mlmodel_dir):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    input_checkpoint = os.path.join(checkpoint_folder, checkpoint_name)

    # Step 1: freeze the checkpoint into a .pb file.
    freeze_graph(input_checkpoint, pb_save_path, output_node_names=output_node_name)

    # Step 2: convert the .pb into an .mlmodel targeting iOS 12, where the
    # input height and width must be fixed (no -1).
    tfcoreml.convert(tf_model_path=pb_save_path,
                     mlmodel_path=mlmodel_save_path,
                     output_feature_names=[output_node_name + ':0'],
                     input_name_shape_dict={input_node_name + ':0': input_shape},
                     minimum_ios_deployment_target='12')

    # iOS 13+ alternative: height and width may be dynamic (-1), e.g.
    # tfcoreml.convert(tf_model_path=pb_save_path,
    #                  mlmodel_path=mlmodel_save_path,
    #                  output_feature_names=[output_node_name + ':0'],
    #                  input_name_shape_dict={input_node_name + ':0': [1, -1, -1, 4]},
    #                  minimum_ios_deployment_target='13')

    # Step 3: declare the model's MultiArray inputs/outputs as FLOAT32.
    convert_CoreMLModel_DOUBLE2FLOAT(mlmodel_save_path)

2. mlmodel转mlmodelc

xcrun coremlc compile xx.mlmodel .
(第二个参数为输出目录,编译结果为该目录下的 xx.mlmodelc)

3. pb转mnn

MNNConvert -f TF --modelFile XXX.pb --MNNModel XXX.mnn --bizCode biz
或(--bizCode 只是写入模型的业务标识字符串,可自定义):
MNNConvert -f TF --modelFile XXX.pb --MNNModel XXX.mnn --bizCode mnn

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值