原文链接: tensorflow 线性模型保存为pb格式,并且在tfjs中使用
上一篇: tensorflowjs 简单使用
保存时需要指定,output_node_names 是准备保存的模型节点名称的列表,只会保存指定节点的数据
@tf_export("graph_util.convert_variables_to_constants")
def convert_variables_to_constants(sess,
input_graph_def,
output_node_names,
variable_names_whitelist=None,
variable_names_blacklist=None):
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
with tf.gfile.FastGFile(SAVE_PATH, mode='wb') as f:
f.write(constant_graph.SerializeToString())
用于拟合直线y=2x-1的模型代码
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from tensorflow.python.framework import graph_util
# 模拟生成数据点, 返回np数组
def get_data(batch_size):
x = np.random.uniform(-10, 10, (batch_size, 1))
y = 2 * x - 1 + np.random.randn(batch_size, 1)
return x, y
SAVE_DIR