初始目录结构
save.py代码
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys
# Synthetic training data: each array is re-seeded independently, so the
# two draws are reproducible and order-independent.
np.random.seed(2)
data_y = np.random.rand(100, 1)
np.random.seed(1)
data_x = np.random.rand(100, 1)

# Where the checkpoint files are written / looked up.
save_file_name = 'model.cpkt'
save_dir_path = 'model'
# Graph input: float64 column vector with a dynamic batch dimension.
# FIX: the original bound the scope to an unused local (`as scope`);
# the binding is dropped here.
with tf.name_scope('myPlaceholder'):
    x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')
def model():
    """Build the linear model a * x + b from the module-level tensors.

    Uses globals `a`, `b` (variables) and `x` (placeholder). The output op
    is explicitly named "linear_model" so it can be fetched by name later.
    """
    scaled = tf.multiply(a, x)
    return tf.add(scaled, b, name="linear_model")
def process():
    """Run 1000 gradient-descent steps, print the result, save a checkpoint.

    Relies on module-level `sess`, `train`, `a`, `b`, `loss_function`,
    `feed_dict_x` and `saver` being set up by the caller.
    """
    for step in tqdm(range(1000)):
        _, value_a, value_b, value_loss = sess.run(
            [train, a, b, loss_function], feed_dict=feed_dict_x)
    print('训练之后', value_a, value_b, 'loss', value_loss)
    saver.save(sess, save_path=os.path.join(save_dir_path, save_file_name))
if __name__ == '__main__':
    # allow_growth keeps TF from reserving all GPU memory up front; on a
    # CPU-only machine a plain tf.Session() works just as well.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    saver = None
    if not tf.train.checkpoint_exists(save_dir_path):
        # First run: build the graph from scratch and initialize variables.
        # NOTE(review): checkpoint_exists is handed the directory, not the
        # checkpoint prefix -- confirm it detects the checkpoint correctly.
        os.mkdir(save_dir_path)
        a = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='a')
        b = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='b')
        saver = tf.train.Saver(max_to_keep=1)
        sess.run(tf.global_variables_initializer())
        model_output = model()
        loss_function = tf.reduce_mean(tf.square(model_output - data_y), name='loss')
        train = tf.train.GradientDescentOptimizer(learning_rate=0.002).minimize(loss_function, name='minimize')
        feed_dict_x = {x: data_x}
    else:
        # Restore both the graph structure (.meta) and the variable values,
        # then look up the tensors/ops by the names given at build time.
        saver = tf.train.import_meta_graph(os.path.join(os.getcwd(), save_dir_path, "model.cpkt.meta"))
        saver.restore(sess, os.path.join(save_dir_path, save_file_name))
        graph = tf.get_default_graph()
        a = graph.get_tensor_by_name('a:0')
        b = graph.get_tensor_by_name('b:0')
        print('恢复模型', sess.run(a), sess.run(b))
        loss_function = graph.get_tensor_by_name('loss:0')
        train = graph.get_operation_by_name('minimize')
        # NOTE(review): the module-level placeholder already claimed
        # "myPlaceholder/x", so the imported one presumably became "x_1" --
        # confirm against the saved graph.
        feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x_1:0"): data_x}
        op = input("""选择继续训练或者使用模型进行预测(1:训练 2:预测)""")
        # FIX: the original compared string values with `is` (identity),
        # which only works by accident of CPython interning; use `==`.
        if op == '1':
            pass
        elif op == '2':
            # Interactive prediction: loops forever reading one x per turn.
            while True:
                input_x = np.array([[input("输入x:")]], dtype=np.float64)
                feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x_1:0"): input_x}
                output = sess.run(graph.get_tensor_by_name("linear_model:0"), feed_dict=feed_dict_x)
                print(output)
        else:
            sys.exit()
    process()
    sess.close()
上面的恢复模型是通过加载已经持久化的图,而下面的是通过已经定义好的图上的运算。区别在于:使用上面的代码恢复模型时,即使注释掉了model()方法,依旧能正常运行,因为它不依靠已经定义好的运算;而下面的代码在恢复模型时,只把变量的值加载了进来,需要重新定义图上的运算。
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys
# Reproducible synthetic data: each array is re-seeded independently, so
# the two draws do not depend on each other's order.
np.random.seed(2)
data_y = np.random.rand(100, 1)
np.random.seed(1)
data_x = np.random.rand(100, 1)

# Checkpoint location shared by the save and restore paths below.
save_file_name = 'model.cpkt'
save_dir_path = 'model'
def model():
    """Return the linear-model op a * x + b, named "linear_model".

    Depends on module-level `a`, `b` (variables) and `x` (placeholder).
    """
    product = tf.multiply(a, x)
    return tf.add(product, b, name="linear_model")
def process():
    """Train for 1000 steps, print the final state, write a checkpoint.

    Uses the module-level `sess`, `train`, `a`, `b`, `loss_function`,
    `feed_dict_x` and `saver`.
    """
    fetches = [train, a, b, loss_function]
    for step in tqdm(range(1000)):
        _, value_a, value_b, value_loss = sess.run(fetches, feed_dict=feed_dict_x)
    print('训练之后', value_a, value_b, 'loss', value_loss)
    saver.save(sess, save_path=os.path.join(save_dir_path, save_file_name))
if __name__ == '__main__':
    # allow_growth keeps TF from reserving all GPU memory up front; on a
    # CPU-only machine a plain tf.Session() works just as well.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Build the full graph on every run; only the variable VALUES are
    # restored from the checkpoint (contrast with import_meta_graph above).
    with tf.name_scope('myPlaceholder') as scope:
        x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')
    a = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='a')
    b = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='b')
    model_output = model()
    loss_function = tf.reduce_mean(tf.square(model_output - data_y), name='loss')
    train = tf.train.GradientDescentOptimizer(learning_rate=0.002).minimize(loss_function, name='minimize')
    feed_dict_x = {x: data_x}
    saver = tf.train.Saver(max_to_keep=1)

    if not tf.train.checkpoint_exists(save_dir_path):
        # First run: create the checkpoint directory and fresh variables.
        # NOTE(review): checkpoint_exists is handed the directory, not the
        # checkpoint prefix -- confirm it detects the checkpoint correctly.
        os.mkdir(save_dir_path)
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, os.path.join(save_dir_path, save_file_name))

    op = input("""选择训练或者使用模型进行预测(1:训练 2:预测)""")
    # FIX: the original compared string values with `is` (identity), which
    # only works by accident of CPython interning; use `==`.
    if op == '1':
        pass
    elif op == '2':
        # Interactive prediction: loops forever reading one x per turn.
        while True:
            input_x = np.array([[input("输入x:")]], dtype=np.float64)
            feed_dict_x = {x: input_x}
            output = sess.run(model(), feed_dict=feed_dict_x)
            print(output)
    else:
        sys.exit()
    process()
    sess.close()