TensorFlow加载多个模型

通常我们在开发中根据不同任务需要不同的预训练模型,因此需要同时加载多个模型文件。但是同时加载多个TensorFlow预训练模型时,若还是采用加载单个模型文件一样的方式则会因图冲突而加载失败。主要是因为不同对象里面的不同sess使用了同一进程空间下的相同的默认图graph。 因此,我们需要为为每个类(实例)单独创建一个graph

g1 = tf.Graph() #为每个类(实例)单独创建一个graph
g2 = tf.Graph()

sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
#亲测,若你训练模型时指定了设备,如上一行代码,则你restore时也要加上,不然会出错。
sess1 = tf.Session(graph=g1, config=sess_config)
sess2 = tf.Session(graph=g2, config=sess_config)

#加载模型1,
with sess1.as_default():
        with sess1.graph.as_default():
            tf.global_variables_initializer().run()
            model_saver = tf.train.import_meta_graph(model_path_1+'model.meta')
            model_cpt = tf.train.get_checkpoint_state(model_path_1)
            model_saver.restore(sess1, model_cpt.model_checkpoint_path)
            graph = tf.get_default_graph()
            x1 = graph.get_tensor_by_name('x1:0')
            x2 = graph.get_tensor_by_name('keep_prob:0')
            y1 = graph.get_tensor_by_name('y1:0')

#加载模型2
with sess2.as_default():
        with sess2.graph.as_default():
            tf.global_variables_initializer().run()
            model_saver = tf.train.import_meta_graph(model_path_2+'model.meta')
            model_ckpt = tf.train.get_checkpoint_state(model_path_2)
            model_saver.restore(sess2, model_ckpt.model_checkpoint_path)
            graph = tf.get_default_graph()
            x1 = graph.get_tensor_by_name('x1:0')
            x2 = graph.get_tensor_by_name('x2:0')
            x3 = graph.get_tensor_by_name('keep_prob:0')
            y2 = graph.get_tensor_by_name('y2:0')

#使用模型1
with sess1.as_default():
        with sess1.graph.as_default():
            feed = {x1: a, x2: 1.0}
            res1 = sess1.run(y1, feed_dict=feed)
            .....

#使用模型2
with sess2.as_default():
        with sess2.graph.as_default():
            feed = {x1: b, x2: c,x3:1.0}
            res2 = sess2.run(y2, feed_dict=feed)
            .....

使用上述方式时,你在定义模型时一定要给变量命名,不然不方便获取tensor。此种方式不需要你重新定义模型。

我在实际操作过程中,试了另一种方式:还是要为每个实例单独建图,但是在加载时你再次定义模型。这种方式不用你一个一个的根据变量名来获取tensor。

g1 = tf.Graph() #为每个类(实例)单独创建一个graph
g2 = tf.Graph()

sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)

sess1 = tf.Session(graph=g1, config=sess_config)
sess2 = tf.Session(graph=g2, config=sess_config)

#加载模型1
with sess1.as_default():
        with sess1.graph.as_default():
            #此处与加载单模型方式一样,但你得再次定义模型。
            #若你以前的模型文件是定义为类的,则此处非常方便
            #例如我的模型
            model1 = CNNModel1(config.MAX_SEQ_LENGTH,
                                 np.array(embedded),
                                 config.EMBEDDING_DIM,
                                 config.KERNEL_SIZE,
                                 config.NUM_FILTER,
                                 config.DROPOUT_KEEP_PROB,
                                 config.MARGIN_VALUE,
                                 config.LEARNING_RATE,
                                 config.L2_REG)
            #接下来的就很熟悉了
            sess1.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess1, model_path_1+'model')

#加载模型2
with sess2.as_default():
        with sess2.graph.as_default():
            model2 = CNNModel2(config.MAX_SEQ_LENGTH,
                                 np.array(embedded),
                                 config.EMBEDDING_DIM,
                                 config.KERNEL_SIZE,
                                 config.NUM_FILTER,
                                 config.DROPOUT_KEEP_PROB,
                                 config.MARGIN_VALUE,
                                 config.LEARNING_RATE,
                                 config.L2_REG)
            #接下来的就很熟悉了
            sess2.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess2, model_path_2+'model')

#使用模型的方法与之前是一样的。
with sess1.as_default():
        with sess1.graph.as_default():
            feed1 = {model1.x:a,model1.keep_prob:1.0}
            res1 = sess1.run(model1.y1, feed_dict=feed1)
            ....

with sess2.as_default():
        with sess2.graph.as_default():
            feed2 = {model1.x1:a,model.x2:b,model1.keep_prob:1.0}
            res2 = sess2.run(model1.y2, feed_dict=feed1)
            ....
            

以上表述全为实际测试得到的经验,若有错误欢迎指出。若大家还有其它方案,也欢迎讨论。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值