tf中线程与graph读取的关系

def import_graph_fun(pb_model_name):
    output_graph_def = tf.GraphDef()
    with open(pb_model_name, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    sess = tf.Session()
    # other code

当我们训练好模型之后,将模型存储成pb格式,然后上面这段代码是读取该文件,一直以来都没有问题,直到有一天当我们需要在一个线程中运行多个模型的时候报错,报错的内容大概就是在预测阶段,不能从graph中找到对应节点。

之前我们都知道在tf框架中:1. 所有的graph都要在session中运行,并且一个session中只能运行一个graph,但是同样的graph可以在不同的session中运行; 2. 如果没有指定graph,框架会为我们生成默认的graph;3. 所有的op操作都会添加到模型的graph上。

其实进行到了这里,基本就能解开上面说的报错的原因:如果我们在一个线程中跑多个graph,我们就必须要有多个session,并且要给每个session绑定它对应的graph,之所以在一个线程中只跑一个图一直没出错是因为就一个图,默认这个图就是绑定了这个session的,不存在歧义。

其实我当时在写这里的时候有个疑问:这条语句graph_def.ParseFromString(f.read())是从pb文件中将序列化的graph解析出来,然后根据这条语句tf.import_graph_def(graph_def, name=“”)将这个graph_def导入,那么问题是这个导入是导入到了哪里?该接口并没有将graph return回来,它去了哪里,我们怎么拿到它?答案是:该接口直接将import出来的graph中所有的op添加到了它对应的上下文的graph中,想要获取它,就要先构造一个上下文环境,然后才能拿到这个graph,具体代码为:

def import_graph_fun(pb_model_name):
    output_graph_def = tf.GraphDef()
    with open(pb_model_name, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        
        # 注意:这条语句非常重要,通过接口as_default()构建了一个上下文环境,此时tf.import_graph_def就是将op添加到了这个环境所对应的graph中,也就是g_
        with tf.Graph().as_default() as g_:
            tf.import_graph_def(output_graph_def, name="")
    
    # 这里指定一下这个sess1绑定的是g_
    sess1 = tf.Session(graph=g_)
    # other code

 

如此,就可以work了。

总结:graph是跟线程相关的。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值