首先创建两个计算图（tf.Graph）：
# Build each network in its own tf.Graph so that their ops and variables stay
# separate: ops are added to whichever graph is the default at creation time.
# (Assumes Model / Adversary create their ops in the constructor — confirm.)
model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)
然后为每个计算图分别建立一个会话：
# A tf.Session is bound to the graph passed at construction and runs only
# that graph's ops, so each graph gets a dedicated session. The tuple is
# evaluated left to right: the adversary session is created first.
adv_sess, sess = (
    tf.Session(graph=adv_graph),
    tf.Session(graph=model_graph),
)
接着在每个会话中初始化变量，并分别从各自的检查点恢复对应的计算图：
# Initialize and restore each graph inside its own session. Both context
# managers are needed: sess.as_default() makes `.run()` execute in this
# session, while <graph>.as_default() makes newly created ops (the
# initializer, the Saver's save/restore ops) land in the matching graph.
with sess.as_default(), model_graph.as_default():
    # The Saver covers every global variable, so restore() overwrites all of
    # them; running the initializer first is redundant but harmless.
    tf.global_variables_initializer().run()
    model_saver = tf.train.Saver(tf.global_variables())
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    # get_checkpoint_state() returns None when save_dir holds no checkpoint;
    # fail with a clear message instead of an AttributeError on None.
    if model_ckpt is None or not model_ckpt.model_checkpoint_path:
        raise FileNotFoundError(
            "No model checkpoint found in {!r}".format(args.save_dir))
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default(), adv_graph.as_default():
    tf.global_variables_initializer().run()
    adv_saver = tf.train.Saver(tf.global_variables())
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    if adv_ckpt is None or not adv_ckpt.model_checkpoint_path:
        raise FileNotFoundError(
            "No adversary checkpoint found in {!r}".format(adv_args.save_dir))
    adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
此后，每当需要使用某个会话时，都要用该会话的 `as_default()`（以及对应计算图的 `as_default()`）包装在该会话中调用的所有 tf 函数。使用完毕后，手动关闭会话：
# Release both sessions' resources explicitly once all work with them is
# done; the model session is closed first, then the adversary session.
for _session in (sess, adv_sess):
    _session.close()
转载自：http://www.voidcn.com/article/p-wgywhvux-bvm.html