通常我们在开发中根据不同任务需要不同的预训练模型,因此需要同时加载多个模型文件。但是同时加载多个TensorFlow预训练模型时,若还是采用加载单个模型文件一样的方式则会因图冲突而加载失败。主要是因为不同对象里面的不同sess使用了同一进程空间下的相同的默认图graph。 因此,我们需要为为每个类(实例)单独创建一个graph
g1 = tf.Graph() #为每个类(实例)单独创建一个graph
g2 = tf.Graph()
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
#亲测,若你训练模型时指定了设备,如上一行代码,则你restore时也要加上,不然会出错。
sess1 = tf.Session(graph=g1, config=sess_config)
sess2 = tf.Session(graph=g2, config=sess_config)
#加载模型1,
with sess1.as_default():
with sess1.graph.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.import_meta_graph(model_path_1+'model.meta')
model_cpt = tf.train.get_checkpoint_state(model_path_1)
model_saver.restore(sess1, model_cpt.model_checkpoint_path)
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name('x1:0')
x2 = graph.get_tensor_by_name('keep_prob:0')
y1 = graph.get_tensor_by_name('y1:0')
#加载模型2
with sess2.as_default():
with sess2.graph.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.import_meta_graph(model_path_2+'model.meta')
model_ckpt = tf.train.get_checkpoint_state(model_path_2)
model_saver.restore(sess2, model_ckpt.model_checkpoint_path)
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name('x1:0')
x2 = graph.get_tensor_by_name('x2:0')
x3 = graph.get_tensor_by_name('keep_prob:0')
y2 = graph.get_tensor_by_name('y2:0')
#使用模型1
with sess1.as_default():
with sess1.graph.as_default():
feed = {x1: a, x2: 1.0}
res1 = sess1.run(y1, feed_dict=feed)
.....
#使用模型2
with sess2.as_default():
with sess2.graph.as_default():
feed = {x1: b, x2: c,x3:1.0}
res2 = sess2.run(y2, feed_dict=feed)
.....
使用上述方式时,你在定义模型时一定要给变量命名,不然不方便获取tensor。此种方式不需要你重新定义模型。
我在实际操作过程中,试了另一种方式:还是要为每个实例单独建图,但是在加载时你再次定义模型。这种方式不用你一个一个的根据变量名来获取tensor。
g1 = tf.Graph() #为每个类(实例)单独创建一个graph
g2 = tf.Graph()
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
sess1 = tf.Session(graph=g1, config=sess_config)
sess2 = tf.Session(graph=g2, config=sess_config)
#加载模型1
with sess1.as_default():
with sess1.graph.as_default():
#此处与加载单模型方式一样,但你得再次定义模型。
#若你以前的模型文件是定义为类的,则此处非常方便
#例如我的模型
model1 = CNNModel1(config.MAX_SEQ_LENGTH,
np.array(embedded),
config.EMBEDDING_DIM,
config.KERNEL_SIZE,
config.NUM_FILTER,
config.DROPOUT_KEEP_PROB,
config.MARGIN_VALUE,
config.LEARNING_RATE,
config.L2_REG)
#接下来的就很熟悉了
sess1.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess1, model_path_1+'model')
#加载模型2
with sess2.as_default():
with sess2.graph.as_default():
model2 = CNNModel2(config.MAX_SEQ_LENGTH,
np.array(embedded),
config.EMBEDDING_DIM,
config.KERNEL_SIZE,
config.NUM_FILTER,
config.DROPOUT_KEEP_PROB,
config.MARGIN_VALUE,
config.LEARNING_RATE,
config.L2_REG)
#接下来的就很熟悉了
sess2.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess2, model_path_2+'model')
#使用模型的方法与之前是一样的。
with sess1.as_default():
with sess1.graph.as_default():
feed1 = {model1.x:a,model1.keep_prob:1.0}
res1 = sess1.run(model1.y1, feed_dict=feed1)
....
with sess2.as_default():
with sess2.graph.as_default():
feed2 = {model1.x1:a,model.x2:b,model1.keep_prob:1.0}
res2 = sess2.run(model1.y2, feed_dict=feed1)
....
以上表述全为实际测试得到的经验,若有错误欢迎指出。若大家还有其它方案,也欢迎讨论。