- 暂存模型(*.index 为变量索引(变量名到数据位置的映射), *.meta 为模型图, *.data-* 为参数值)
# Build a small graph, initialize it and write a checkpoint
# (<prefix>.index, <prefix>.meta, <prefix>.data-*) under MODEL_DIR/MODEL_NAME.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
# The context manager closes the session even if run()/save() raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
- 暂存模型(同一模型多次保存可以不保存模型图节省时间)
# Save the same graph several times. Only the first save writes the .meta
# graph; later saves pass write_meta_graph=False and write only the
# .index/.data-* files.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
# The context manager closes the session even if run()/save() raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    # NOTE(review): presumably the pauses give each checkpoint a distinct
    # timestamp for the demo -- confirm they are still wanted.
    time.sleep(5)
    # Same "%s/%s" prefix format as every other save/restore path; the old
    # "%s/%s1" appended a stray "1" to both MODEL1_NAME and MODEL2_NAME.
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
    time.sleep(5)
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)
- 恢复模型(在代码中手动重建网络、定义同名变量,则不需要*.meta文件)
# Rebuild the graph in code (same variable names as when it was saved), then
# load the saved values from MODEL_DIR/MODEL_NAME. No .meta file is needed.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
# The context manager closes the session even if restore()/run() raises.
with tf.Session() as sess:
    # restore() assigns the saved values, so no initializer is run here.
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    print(sess.run([weights]))
- 恢复模型(从*.meta文件生成网络)
# Recreate the graph from the saved .meta file, then restore the variable values.
tf.reset_default_graph()
saver = tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
# The context manager closes the session even if restore()/run() raises.
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    graph = tf.get_default_graph()
    all_op = graph.get_operations()  # every op in the imported graph
    # tf.all_variables() is deprecated; tf.global_variables() is its replacement.
    all_var = tf.global_variables()
    # Tensors are looked up by "<op name>:<output index>".
    print(sess.run([graph.get_tensor_by_name("weights:0")]))
- 恢复模型(可以在一个文件夹下保存多次模型,checkpoint文件会自动记录所有模型名称和最后一次记录模型名称)
# Restore the checkpoint that MODEL_DIR's "checkpoint" file records as the
# most recent one.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
# get_checkpoint_state returns None when MODEL_DIR has no "checkpoint" file;
# fail with a clear message instead of an AttributeError on None.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No checkpoint found in %s" % MODEL_DIR)
# The context manager closes the session even if restore()/run() raises.
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(sess.run([weights]))
- 微调模型(恢复之前训练模型的部分参数,加上新参数,继续训练)
def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True):
ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
if not include_global_step:
ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
vars_in_ckpt = {}
for variable_name, variable in sorted(variables.items()):
if variable_name in ckpt_vars_to_shape_map:
if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
vars_in_ckpt[variable_name] = variable
return vars_in_ckpt
# Fine-tuning: build a graph that contains the old variables plus a new one,
# initialize the old ones from the checkpoint and the new one from scratch.
tf.reset_default_graph()
ckpt_path = "%s/%s" % (MODEL_DIR, MODEL_NAME)
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
# New parameter; presumably absent from the checkpoint, so it keeps its own
# zeros initializer -- confirm the checkpoint has no tensor named "Variable".
other_weights = tf.Variable(tf.zeros([10, 10]))
variables_to_init = tf.global_variables()
variables_to_init_dict = {var.op.name: var for var in variables_to_init}
available_var_map = get_variables_available_in_checkpoint(
    variables_to_init_dict, ckpt_path, include_global_step=False)
# Redirect the matched variables' initializers to read from the checkpoint,
# so the global initializer below loads their saved values.
tf.train.init_from_checkpoint(ckpt_path, available_var_map)
# The context manager closes the session even if run() raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
- 保存模型(二进制 *.pb 模型,变量被固化为常量)
from tensorflow.python.framework.graph_util import convert_variables_to_constants
# Restore the checkpoint, convert variables to constants ("freeze"), and write
# the resulting GraphDef as a binary .pb file.
tf.reset_default_graph()
saver = tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
# The context manager closes the session even if restore()/convert raises.
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    # Only the subgraph needed to compute the listed output nodes is kept;
    # the variables it reads become Const nodes holding their current values.
    graph_out = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['weights'])
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
    output.write(graph_out.SerializeToString())
- 加载模型(二进制模型)
# Load the frozen .pb GraphDef and evaluate a tensor from it.
tf.reset_default_graph()
# GFile instead of the deprecated FastGFile, matching the save snippet.
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# as_default() only takes effect inside a `with`; the bare call it replaces
# was a no-op. Import into an explicit graph and bind the session to it.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
# A frozen graph holds constants, not variables, so no initializer is run.
with tf.Session(graph=graph) as sess:
    print(sess.run([graph.get_tensor_by_name("weights:0")]))
参考文献:
https://blog.csdn.net/loveliuzz/article/details/81661875