加载tf模型 正确率很低_tensorflow修改模型输入

CTR模型如果用estimator训练出来,直接导出为savedModel格式,上线时会有冗余计算,一般需要只计算模型的一个子图即可。

假设有下面这样一个计算图:

6b1baa68f86b6ba5a66f019d4ad5da99.png

其中:

a是一个placeholder

b、d是一个variable

c是一个tf.add节点

out是一个tf.multiply节点

一般情况下,我们保存了这个图,然后再加载,输入把a以feed_dict的形式传入参数就可以对图进行计算,现在的需求是,把其中任意一个节点替换掉,比如不计算a+b,直接给一个输入c;或者d不再是常量,而是运行是传入参数。这些需求都可以用tf的API实现。具体参考代码:

保存模型:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.placeholder(dtype=tf.float32, shape=(1, 2), name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')

c = tf.add(a, b, name='c')
d = tf.Variable(2, dtype=tf.float32, name='d')

out = tf.multiply(c, d, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ['out'])
    tf.train.write_graph(graph, '.', 'graph_placeholder.pb', as_text=False)
    tf.train.write_graph(graph, '.', 'graph_placeholder_txt.pb', as_text=True)

加载模型:

import numpy as np
#import tensorflow as tf
from google.protobuf import text_format
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Session() as sess:
    with open('./graph_placeholder.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        data = np.array([[3.],[2.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'a:0': data}, return_elements=['out:0'])
        print(sess.run(output))

with tf.Session() as sess:
    with open('./graph_placeholder.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        data = np.array([[3.],[2.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'c:0': 5.}, return_elements=['out:0'])
        print(sess.run(output))

with tf.Session() as sess:
    with open('./graph_placeholder_txt.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        text_format.Merge(f.read(), graph_def)
        data = np.array([[7.], [3.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'a:0': data, 'd:0': 9.}, return_elements=['out:0'])
        print(sess.run(output))

参考资料:

https://tang.su/2017/01/export-TensorFlow-network/

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值