tensorflow保存和加载模型
构建手写数字识别神经网络的时候对输入输出值设置name分别为input和output
x_data = tf.placeholder(tf.float32, [None, 784], name = 'input')
# 模型输入节点,name = 'input'
y_data = tf.placeholder(tf.float32, [None, 10])
y = (w * x + b, name = 'output')
# 模型输出节点,name = 'input',函数简单表示为y=w*x+b的形式
对训练出来的模型进行保存
saver = tf.train.Saver() # 导入模型保存类
saver.save(sess, '模型保存路径\\model') # sess是当前会话
模型调用
tf.reset_default_graph()
# 首先重置计算图
saver = tf.train.import_meta_graph('模型保存路径\\model.meta')
# 导入保存的计算图
saver.restore(sess, '模型保存路径\\model')
# 重启会话sess,激活model及所有变量
gh = tf.get_default_graph()
# 获取当前默认的计算图
input = gh.get_tensor_by_name('input:0')
# 导入模型输入节点,input:0代表导入模型中第一个name='input'的变量
output = gh.get_tensor_by_name('output:0') # 模型输出节点
然后在input里面传入数据,经过模型的计算output得出结果