How to restore two TensorFlow graphs in the same process
如果还按照普通的方法(在默认图中直接构建):
# Single-graph approach: everything is built in the process-wide default graph.
with tf.name_scope('blabla'):  # was curly quotes ‘blabla’ — a SyntaxError as pasted
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        # Build the graph: placeholder for one H x W x 3 image (dynamic size).
        self.image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
        self.processed_image, _, _, _, _, self.pad_mod = ssd_vgg_preprocessing.preprocess_image(
            self.image, None, None, None, None,
            out_shape=config.image_shape,
            data_format=config.data_format,
            is_training=False)
        before_image = tf.cast(self.image, tf.float32)
        before_image = tf.expand_dims(before_image, 0)
        after_image = tf.expand_dims(self.processed_image, 0)
        b_image = tf.expand_dims(self.processed_image, axis=0)
        # Build model and loss.
        self.net = pixel_link_symbol.PixelLinkNet(b_image, config=config)
        self.masks = pixel_link.tf_decode_score_map_to_mask_in_batch(
            self.net.pixel_pos_scores, self.net.link_pos_scores)
# Collect the variable list to restore.
# NOTE: if several graphs were previously built in this process, this lists the
# variables of ALL of them, because it reads the default
# GraphKeys.GLOBAL_VARIABLES collection — that is exactly the problem.
variables_to_restore = slim.get_variables_to_restore()
saver = tf.train.Saver(var_list=variables_to_restore)
self.sess = tf.Session()
saver.restore(self.sess, checkpath)
那么应该怎么办:
#如果有多个图,就每个使用一个tf.Graph()来包裹
# Solution: wrap EACH model in its own tf.Graph() so variable collections
# (and hence Saver/restore) stay isolated per model.
g = tf.Graph()
with g.as_default():
    with tf.name_scope('blabla'):  # was curly quotes ‘blabla’ — a SyntaxError as pasted
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            # Build the graph: placeholder for one H x W x 3 image (dynamic size).
            self.image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
            self.processed_image, _, _, _, _, self.pad_mod = ssd_vgg_preprocessing.preprocess_image(
                self.image, None, None, None, None,
                out_shape=config.image_shape,
                data_format=config.data_format,
                is_training=False)
            before_image = tf.cast(self.image, tf.float32)
            before_image = tf.expand_dims(before_image, 0)
            after_image = tf.expand_dims(self.processed_image, 0)
            b_image = tf.expand_dims(self.processed_image, axis=0)
            # Build model and loss.
            self.net = pixel_link_symbol.PixelLinkNet(b_image, config=config)
            self.masks = pixel_link.tf_decode_score_map_to_mask_in_batch(
                self.net.pixel_pos_scores, self.net.link_pos_scores)
    # Collect the variable list from THIS graph only — do NOT use
    # slim.get_variables_to_restore(), which reads the default graph's
    # GLOBAL_VARIABLES and would mix in variables from other models.
    # variables_to_restore = slim.get_variables_to_restore()
    variables_to_restore = g.get_collection('variables')
    saver = tf.train.Saver(var_list=variables_to_restore)
    self.sess = tf.Session(graph=g)  # the session must target this graph, not the default one
    saver.restore(self.sess, checkpath)