TensorFlow estimator加载pb模型预测TfRecord中的样本

本文介绍了如何使用TensorFlow Estimator加载导出的PB模型,并从TfRecord文件中获取数据进行预测。在读取TfRecord文件时,需要启动Session以获取数据值。解析tf record部分涉及DataReader对象来管理数据格式。在拼接模型输入时,需要注意字段名称和维度匹配,避免出现维度错误。
摘要由CSDN通过智能技术生成

TensorFlow estimator加载pb模型预测TfRecord中的样本

最近在上一个模型,需要测试模型输出与线上server predict值的一致性,模型用到了高级API estimator以及tf record格式的数据,很多地方还不是很熟悉。为了测一致性我构造了一个只有一条样本的数据,然后想在本地加载模型export出来的pb格式文件,再load 数据输出预测值。在我的场景下是要捞模型某一层的数据,具体实现如下:

    def predict_with_modelpb(self, flags_obj, filenames):
        """
        filenames: 要读取的文件名list
        """

        if len(filenames) == 0:
            tf.logging.error("can not find any input files!")
            return None
        print(filenames)
        # 要加载的pb模型路径
        export_dir = "./model-local/export/saved_model/1584341279"

        self.data_reader = DataReader(None, flags_obj.feature_config_path)
        with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            tf.saved_model.loader.load(session, ["serve"], export_dir)
            tf.get_default_graph()
            # 从模型结构output中获取输出层name
            y = session.graph.get_tensor_by_name('bn4_1/LeakyRelu:0')
            # 从tf record中读取数据
            options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
            filename_queue = tf.train.string_input_producer([filenames[0]], )
            reader = tf.TFRecordReader(name='example_reader', options=options)
            _, serialized_example = reader.read(filename_queue)
            feature_dict = self.data_reader.build_example
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值