##validation
import os
import numpy as np
import tensorflow as tf
# Clear the default graph first so that re-running this script/cell does not
# stack a second copy of the ops on top of the previous one.
tf.reset_default_graph()

# ---- validation configuration ----
N_CLASSES = 2          # number of classes, passed to inference()
IMG_W, IMG_H = 32, 32  # target image size handed to get_batch (resize)
BATCH_SIZE = 64        # examples per validation batch
CAPACITY = 2000        # input queue capacity, passed to get_batch
MAX_STEP = 8000        # upper bound on the number of validation iterations
# ---- build the validation input pipeline, model and metric ops ----
valid, valid_label = get_files()
logs_valid_dir = './train/valid'  # TensorBoard output directory for validation
valid_batch, valid_label_batch = get_batch(
    valid, valid_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY
)

# NOTE (translated from the original author's Chinese comment): a
# `with tf.Graph().as_default():` line here caused a long debugging session.
# Handle it with care, otherwise TensorFlow keeps raising
# "Tensor must be from the same graph as Tensor". Everything below is built in
# the single default graph, so the input batches above and the model share it.
ckpt_path = './train_single_channel/model.ckpt-4999999'
valid_logits = inference(valid_batch, BATCH_SIZE, N_CLASSES)
valid__acc = evaluation(valid_logits, valid_label_batch, 'valid')
val_recall, val_precision = recall_precision(
    valid_logits, valid_label_batch, 'valid'
)
# Saver over the graph's saveable variables; used to restore ckpt_path.
saver = tf.train.Saver()
# ---- run validation: restore the checkpoint, then evaluate batch by batch ----
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # NOTE(review): if get_batch uses num_epochs, or evaluation()/
    # recall_precision() are built on tf.metrics, those counters are *local*
    # variables. The Saver does not restore them and they must be initialised.
    # If there are none, this op is a no-op.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Overwrite the freshly initialised weights with the trained checkpoint.
    saver.restore(sess, ckpt_path)
    print('--------restore done---------')

    # merge_all() returns None when the graph defines no summary ops.
    summary_op = tf.summary.merge_all()
    valid_writer = tf.summary.FileWriter(logs_valid_dir, sess.graph)

    # Fetch the metrics (and summaries) in ONE run. Each sess.run dequeues a
    # new batch, so separate runs would log summaries for a different batch.
    # The results go into new names; rebinding val_recall / val_precision
    # would replace the tensors with floats and break the next iteration.
    fetches = [valid__acc, val_recall, val_precision]
    if summary_op is not None:
        fetches.append(summary_op)

    print('---------validation start----------')
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            results = sess.run(fetches)
            acc_value, recall_value, precision_value = results[:3]
            print('Step %d, valid accuracy = %.2f%%, valid recall = %.2f%%, valid precision = %.2f%%'
                  % (step, acc_value * 100.0, recall_value * 100.0, precision_value * 100.0))
            if summary_op is not None:
                valid_writer.add_summary(results[3], step)
    except tf.errors.OutOfRangeError:
        # Raised by the input queue once its data is exhausted.
        print('Done validation -- epoch limit reached')
    finally:
        # Stop the queue-runner threads, wait for them, and flush summaries.
        coord.request_stop()
        coord.join(threads)
        valid_writer.close()