- 模型测试
模型训练好之后通过重新加载模型的方式进行模型测试,使用Tensorflow中的Saver对象。相关代码如下:
主函数为:def test_cnn(x_data): output = create_cnn(4) saver = tf.train.Saver() with tf.Session() as sess: #加载训练好的模型 saver.restore(sess,"./model/cnn.model-2100") preject = tf.argmax(output,1) x_in = np.array(x_data) #keep_prob设置为1 label = sess.run(preject,feed_dict={X:[x_in],keep_prob:1}) return label
if __name__ == '__main__': isTrain = 2 if 1 == isTrain: X = tf.placeholder(tf.float32,[None,200,150,3]) Y = tf.placeholder(tf.float32,[None,4]) keep_prob = tf.placeholder(tf.float32) train_cnn(xdata,ydata) if 2 == isTrain: #将测试数据放在相应的文件中 path_list = ['./0','./1','./2','./3'] for p in path_list: file_info = os.listdir(p) for file_name in file_info: x_data = read_test_data(p+'/'+file_name) if type(x_data) == type(None): print('==>',p) continue #没有这句,会出现问题 tf.reset_default_graph() X = tf.placeholder(tf.float32,[None,200,150,3]) keep_prob = tf.placeholder(tf.float32) l = test_cnn(x_data) label = ['girl','beauty girl','boy','handsome boy'] print(p,':',file_name,'====>',label[l[0]])
卷积神经网络简单的应用(三):模型测试
最新推荐文章于 2024-07-22 10:57:33 发布