The following code saves a model. It is just an example for reference.
import tensorflow as tf

# First, you design your mathematical operations
# We are in the default graph scope
# Let's design a variable
v1 = tf.Variable(1., name="v1")
print('v1:', v1)  # prints the Variable object, not its value
v2 = tf.Variable(2., name="v2")
# Let's design an operation
a = tf.add(v1, v2)
# Let's create a Saver object
# By default, the Saver handles every Variable in the default graph
all_saver = tf.train.Saver()
# But you can specify which variables to save and under which names
v2_saver = tf.train.Saver({"v2": v2})
# By default the Session handles the default graph and all its variables
with tf.Session() as sess:
    # Initialize v1 and v2
    sess.run(tf.global_variables_initializer())
    # Now v1 holds the value 1.0 and v2 holds the value 2.0
    # Save all those values; 'test_model/data.chkp' is the checkpoint prefix
    all_saver.save(sess, 'test_model/data.chkp')
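For reference, a V2-format save with default settings does not write a single file called data.chkp; it writes several files that share that prefix. The sketch below uses a hypothetical helper, expected_checkpoint_files (not a TensorFlow API), to list the files you would typically find after the save above. It assumes the default single data shard and write_meta_graph=True.

```python
import os


def expected_checkpoint_files(prefix, num_shards=1):
    """Hypothetical helper: list the files a V2-format Saver.save() call
    typically writes for a given checkpoint prefix."""
    # Index table mapping each variable name to its location in the data shard(s)
    files = [prefix + ".index"]
    # One data file per shard, named data-XXXXX-of-YYYYY
    files += [f"{prefix}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    # Serialized MetaGraphDef (the graph structure), written by default
    files.append(prefix + ".meta")
    # Bookkeeping file in the same directory that records the latest prefix
    files.append(os.path.join(os.path.dirname(prefix), "checkpoint"))
    return files


for name in expected_checkpoint_files("test_model/data.chkp"):
    print(name)
```

Only the prefix names the checkpoint as a whole; the .data-00000-of-00001 file on its own is just one piece of it.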
Now we want to restore the model, so we write the following code:
import tensorflow as tf

# Rebuild the graph from the .meta file written by save()
saver = tf.train.import_meta_graph('test_model/data.chkp.meta')
graph = tf.get_default_graph()
v1_tensor = graph.get_tensor_by_name('v1:0')
with tf.Session() as sess:
    # Wrong: this is the path of the data shard, not the checkpoint prefix
    saver.restore(sess, 'test_model/data.chkp.data-00000-of-00001')
    print(sess.run(v1_tensor))
But it fails with the following error:
DataLossError (see above for traceback): Unable to open table file test_model/data.chkp.data-00000-of-00001: Data loss: file is too short to be an sstable: perhaps your file is in a different file format and you need to use a different restore operator?
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
[[Node: save/RestoreV2/_1 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_6_save/RestoreV2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
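So how do we fix it? It's actually simple: just drop the .data-00000-of-00001 suffix! saver.restore() expects the checkpoint prefix, i.e. the same path that was passed to save() ('test_model/data.chkp'), not the path of one of the files written under that prefix. Roughly speaking, when given the data file's path, TensorFlow finds no matching .index file, falls back to the old (V1) checkpoint format, and tries to read the data file as an sstable, which is why the error complains about the file format.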
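import tensorflow as tf

# Rebuild the graph from the .meta file written by save()
saver = tf.train.import_meta_graph('test_model/data.chkp.meta')
graph = tf.get_default_graph()
v1_tensor = graph.get_tensor_by_name('v1:0')
with tf.Session() as sess:
    # Pass the checkpoint prefix, i.e. the same path given to save()
    saver.restore(sess, 'test_model/data.chkp')
    # Equivalent: saver.restore(sess, tf.train.latest_checkpoint('test_model'))
    print(sess.run(v1_tensor))  # the restored value of v1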
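If you only have the path of one of the checkpoint files (for example from a file picker or a directory listing), you can strip the suffix programmatically. This is a minimal sketch with a hypothetical helper, checkpoint_prefix (not part of TensorFlow); the regex assumes the usual V2 suffixes .index, .meta and .data-XXXXX-of-YYYYY.

```python
import re

# Suffixes that V2 checkpoints append to the prefix passed to Saver.save()
# (assumption: the standard .index / .meta / .data-XXXXX-of-YYYYY naming)
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d+-of-\d+)$")


def checkpoint_prefix(path):
    """Hypothetical helper: turn the path of any single checkpoint file
    into the prefix that Saver.restore() expects."""
    return _CKPT_SUFFIX.sub("", path)


print(checkpoint_prefix("test_model/data.chkp.data-00000-of-00001"))  # test_model/data.chkp
print(checkpoint_prefix("test_model/data.chkp.meta"))                 # test_model/data.chkp
print(checkpoint_prefix("test_model/data.chkp"))                      # test_model/data.chkp
```

A path that is already a prefix passes through unchanged, so the helper is safe to call on either form.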