Saving a TF model as a pb file, then getting ValueError on load: Input 0 of node bilm/Assign was passed float from bilm/Variable:0 incompatible with expected float_ref.
Running the code that loads the pb file:
tf.import_graph_def(graph_def, name='')
throws an exception like the following:
ValueError: Input 0 of node bilm/Assign was passed float from bilm/Variable:0 incompatible with expected float_ref.
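To see in advance which nodes will trigger this error, a minimal sketch can list the nodes whose op consumes a variable reference. It assumes that the set `REF_OPS` below (the ops handled later in this post) is the relevant one. It also uses a hypothetical `FakeNode` stand-in for `NodeDef` so it runs without TensorFlow. With TensorFlow you would pass `graph_def.node` directly.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Tuple

# Ops that take a variable reference (e.g. float_ref) as input in TF 1.x graphs.
# Assumption: this covers the ops handled in this post; extend it if needed.
REF_OPS = {"Assign", "AssignAdd", "AssignSub", "RefSwitch"}


def find_ref_consumers(nodes: Iterable) -> List[Tuple[str, str]]:
    """Return (name, op) for every node whose op is in REF_OPS.

    Works on any object with `.name` and `.op`, e.g. entries of `graph_def.node`.
    """
    return [(n.name, n.op) for n in nodes if n.op in REF_OPS]


@dataclass
class FakeNode:
    """Hypothetical stand-in for NodeDef so the sketch runs without TensorFlow."""
    name: str
    op: str
    input: List[str] = field(default_factory=list)
    attr: Dict[str, object] = field(default_factory=dict)


nodes = [
    FakeNode("bilm/Variable", "VariableV2"),
    FakeNode("bilm/Assign", "Assign", ["bilm/Variable", "bilm/zeros"]),
    FakeNode("pred_ids", "ArgMax", ["logits", "axis"]),
]
print(find_ref_consumers(nodes))  # [('bilm/Assign', 'Assign')]
```

This reads the graph itself, so it can be used instead of patching TensorFlow's source to dump arguments to a log.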
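The cause: when the model is frozen, variables are converted to constants, so the first input of the Assign op becomes a plain float instead of the float_ref (variable reference) it expects. To find out which ops are affected, first add the following after line 507 of TensorFlow's deprecation.py (the exact line may differ between TensorFlow versions):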
# Debug only: dump the arguments seen at this point to a log file
with open('D:\\tools\\s.log', 'w') as f:
    f.write(str(args))
Then search the resulting log file for ops (op) related to Assign. Next, in the code that saves the pb file, add code that rewrites those ops:
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    gd = sess.graph.as_graph_def()
    sess.run(tf.global_variables_initializer())
    # Rewrite ops that consume variable refs so the frozen graph can be imported
    for node in gd.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                # Use the non-ref '/read' output of batch-norm moving statistics
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr: del node.attr['use_locking']
            if 'validate_shape' in node.attr: del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr: del node.attr['use_locking']
    # 'pred_ids' is the output node name; out_dir/out_name give the output location
    converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, ['pred_ids'])
    tf.train.write_graph(converted_graph_def, out_dir, out_name, as_text=False)
This resolves the problem above.
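The node-rewriting loop can also be pulled out into a standalone function so it can be checked without building a session. This is a minimal sketch that mirrors the loop above. The names `strip_ref_ops`, `_drop_attrs` and `FakeNode` are hypothetical. `FakeNode` is only a stand-in for `NodeDef`. The function relies only on `.op`, indexable `.input` and `.attr` supporting `in`/`del`, so the same calls the original loop makes on `gd.node` should also work here.

```python
from dataclasses import dataclass, field
from typing import Dict, List


def _drop_attrs(node, *keys) -> None:
    # Remove attributes that the replacement op does not accept.
    for key in keys:
        if key in node.attr:
            del node.attr[key]


def strip_ref_ops(nodes) -> None:
    """Rewrite ref-consuming ops in place, mirroring the loop in the post."""
    for node in nodes:
        if node.op == "RefSwitch":
            node.op = "Switch"
            for i in range(len(node.input)):
                # Point moving statistics at their non-ref '/read' output.
                if "moving_" in node.input[i]:
                    node.input[i] = node.input[i] + "/read"
        elif node.op == "AssignSub":
            node.op = "Sub"
            _drop_attrs(node, "use_locking")
        elif node.op == "AssignAdd":
            node.op = "Add"
            _drop_attrs(node, "use_locking")
        elif node.op == "Assign":
            node.op = "Identity"
            _drop_attrs(node, "use_locking", "validate_shape")
            if len(node.input) == 2:
                # input0 is the variable ref, input1 the value: keep only the value.
                node.input[0] = node.input[1]
                del node.input[1]


@dataclass
class FakeNode:
    """Hypothetical stand-in for NodeDef so the sketch runs without TensorFlow."""
    name: str
    op: str
    input: List[str] = field(default_factory=list)
    attr: Dict[str, object] = field(default_factory=dict)


nodes = [
    FakeNode("bilm/Assign", "Assign", ["bilm/Variable", "bilm/zeros"],
             {"use_locking": True, "validate_shape": True}),
    FakeNode("bn/AssignSub", "AssignSub", ["bn/moving_mean", "bn/sub"],
             {"use_locking": False}),
    FakeNode("cond/Switch", "RefSwitch", ["bn/moving_mean", "cond/pred_id"]),
]
strip_ref_ops(nodes)
for n in nodes:
    print(n.name, n.op, n.input)
# bilm/Assign Identity ['bilm/zeros']
# bn/AssignSub Sub ['bn/moving_mean', 'bn/sub']
# cond/Switch Switch ['bn/moving_mean/read', 'cond/pred_id']
```

In the saving code this would be called as `strip_ref_ops(gd.node)` just before `convert_variables_to_constants`.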