接上篇 || 模型的调用并验证(tensorflow)

##validation
import os
import numpy as np
import tensorflow as tf
tf.reset_default_graph()
N_CLASSES = 2 
IMG_W = 32  # resize
IMG_H = 32
BATCH_SIZE = 64
CAPACITY = 2000
MAX_STEP = 8000

valid, valid_label = get_files()
logs_valid_dir='./train/valid'

valid_batch,valid_label_batch=get_batch(valid,
                                valid_label,
                                IMG_W,
                                IMG_H,
                                BATCH_SIZE,
                                CAPACITY)
# x = tf.placeholder(tf.float32, [BATCH_SIZE,IMG_W,IMG_H,1])
# y = tf.placeholder(tf.float32, [BATCH_SIZE])
# y = tf.cast(y,tf.int64)

#with tf.Graph().as_default():
###上面这句困扰我很久的问题,需要注意,否则会一直出现 Tensor must be from the same graph as ###Tensor 的报错
ckpt_path = './train_single_channel/model.ckpt-4999999'
    
train_logits = inference(valid_batch, BATCH_SIZE, N_CLASSES)
valid__acc = evaluation(train_logits, valid_label_batch, 'valid')
val_recall, val_precision = recall_precision(train_logits, valid_label_batch, 'valid')
saver = tf.train.Saver()
    
with tf.Session() as sess:
    #valid_batch,valid_label_batch=sess.run([valid_batch,valid_label_batch])
    sess.run(tf.global_variables_initializer())
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     variables = tf.contrib.framework.get_variables_to_restore()
#     variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='softmax_linear']

    saver.restore(sess,ckpt_path)
    print('--------restore done---------')
    summary_op = tf.summary.merge_all() 

    valid_writer = tf.summary.FileWriter(logs_valid_dir, sess.graph)

    for step in np.arange(MAX_STEP):  
        try:
   
            for step in np.arange(MAX_STEP):
                if coord.should_stop():
                        break

                print('---------validation start----------')
                val_acc, val_recall, val_precision = sess.run([valid__acc, val_recall, val_precision])
                #val_acc, cal_recall, val_precision = sess.run([valid__acc, val_recall, val_precision],feed_dict={x:valid_batch,y:valid_label_batch})
                print('Step %d, valid accuracy = %.2f%%, valid recall = %.2f%%, valid precision = %.2f%%' %(step, val_acc*100.0, val_recall*100.0, val_precision*100.0))

                summary_str = sess.run(summary_op)
                valid_writer.add_summary(summary_str, step)

        except tf.errors.OutOfRangeError:
            print('Done validation -- epoch limit reached')

        finally:
            coord.request_stop()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值