tf.get_default_graph() 与 tf.Graph()的区别

当项目运用到加载多个tf模型的时候,要慎用tf.get_default_graph(),因为可能会报错,比如:

NotFoundError (see above for traceback): Key bidirectional/backward_cu_dnnlstm/bias not found in checkpoint
         [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], 
         _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
         [[Node: save/RestoreV2/_157 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", 
         send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, 
         tensor_name="edge_171_save/RestoreV2", tensor_type=DT_FLOAT, 
         _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

这是因为第一个tf模型加载的时候,使用了tf.get_default_graph(),创建了全局性的默认graph。那么第二个tf模型加载的时候,就会使用同一个全局性的默认graph,也就是说,第二个tf模型用了第一个模型的graph,很显然,这样会报错。

正确的做法,应该是每个模型都用自己的graph去加载模型参数,而不应该用同一个全局性的默认graph。具体可以这样做:每个模型加载的时候,都使用tf.Graph()创建自己局部的graph,然后在自己的局部graph下去初始化模型和加载模型参数。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值