代码变动部分:
# Build the network. simplenet is defined elsewhere; it returns the ops used
# below. logits and train_op are not referenced directly in this loop.
logits, train_op, loss, maintain_averages_op, accuracy = simplenet(x, y, class_num)
saver = tf.train.Saver()
init = tf.global_variables_initializer()

# Write a checkpoint and print progress every this many steps.
CHECKPOINT_EVERY = 100

with tf.Session() as sess:
    sess.run(init)
    # Total number of mini-batches over all epochs.
    steps = epochs * len(img_data) // batchsize
    num_samples = len(inputs)
    for step in range(steps):
        # Wrap the offset so epochs after the first cycle through the data
        # again; without the modulo, every slice past the end of `inputs`
        # is empty and empty batches get fed to the graph.
        start = (step * batchsize) % num_samples
        end = start + batchsize
        batch_inputs = inputs[start:end]
        batch_labels = true_labels[start:end]
        # NOTE(review): train_op is never run directly here. This only trains
        # if maintain_averages_op has a control dependency on train_op inside
        # simplenet -- confirm, otherwise no weights are ever updated.
        ls, acc, _ = sess.run(
            [loss, accuracy, maintain_averages_op],
            feed_dict={x: batch_inputs, y: batch_labels},
        )
        if step % CHECKPOINT_EVERY == 0:
            saver.save(sess, model_dir, global_step=step)
            print(' step: ', step, ' loss: ', ls, ' accuracy: ', acc)
    # Also persist the final weights: otherwise up to CHECKPOINT_EVERY - 1
    # steps of training after the last periodic checkpoint would be lost.
    if steps > 0 and (steps - 1) % CHECKPOINT_EVERY != 0:
        saver.save(sess, model_dir, global_step=steps - 1)
模型加载:
# Exported graph definition and the checkpoint whose weights are restored
# into it. The graph structure is the same for every saved step, so any
# .meta file can be paired with any checkpoint prefix.
META_GRAPH_PATH = './models/model-0.meta'
CHECKPOINT_PATH = './models/model-100'

# Import into a fresh graph so nothing already in the default graph (e.g. a
# training graph built earlier in the same process) clashes with the
# imported variables or gets swept into the restoring Saver.
with tf.Graph().as_default() as graph:
    # import_meta_graph returns a Saver for the imported variables; reuse it
    # rather than constructing a second tf.train.Saver(), which would add
    # duplicate save/restore ops to the graph.
    saver = tf.train.import_meta_graph(META_GRAPH_PATH)
    if saver is None:
        # import_meta_graph returns None when the graph has no variables.
        raise ValueError('no variables to restore in %s' % META_GRAPH_PATH)

    # Dump every global variable, then every tensor in the graph.
    for variable in tf.global_variables():
        print(variable)
    # Same traversal as tf.contrib.graph_editor.get_tensors(graph), but using
    # only the core Graph API.
    for op in graph.get_operations():
        for tensor in op.outputs:
            print(tensor)
    # To dump the raw GraphDef instead:
    # for node in graph.as_graph_def().node:
    #     print(node)

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, CHECKPOINT_PATH)
        # Example: inspect a restored value by tensor name.
        # print(sess.run(graph.get_tensor_by_name('Variable_5:0')))