使用Tensorflow制作球鞋识别模型(五)——模型测试

在上一篇博客使用Tensorflow制作球鞋识别模型(四)——训练模型中完成了对模型的训练,整个球鞋识别模型就已经搭建完毕,最后工作将是对模型进行测试。



模型测试

导入包

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import model
from batch import get_files

获取图片

# 获取一张图片
# 输入参数:train,训练图片的路径
# 返回参数:image,从训练图片中随机抽取一张图片
def get_one_image(train):
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]  # 随机选择测试的图片

    img = Image.open(img_dir)
    plt.imshow(img)
    
    # 如果不加plt.show(img)则无法显示图片,只显示结果
    # plt.imshow()函数负责对图像进行处理,并显示其格式,而plt.show()则是将plt.imshow()处理后的函数显示出来。
    plt.show(img)
    
    # 由于图片在预处理阶段以及resize,因此该命令可略
    imag = img.resize([64, 64])  
    image = np.array(imag)
    return image


测试图片

# 测试图片
def evaluate_one_image(image_array):
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 4

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = model.inference(image, BATCH_SIZE, N_CLASSES)

        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[64, 64, 3])

        logs_train_dir = 'D:/PyCharm/PycharmProjects/AJ_Recognition/train_log'

        saver = tf.train.Saver()

        with tf.Session() as sess:

            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')

            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('This is a Air Jordan I with possibility %.6f' % prediction[:, 0])
            elif max_index == 1:
                print('This is a Air Jordan IV with possibility %.6f' % prediction[:, 1])
            elif max_index == 2:
                print('This is a Air Jordan XI with possibility %.6f' % prediction[:, 2])
            else:
                print('This is a Air Jordan XII with possibility %.6f' % prediction[:, 3])

if __name__ == '__main__':
    train_dir = 'D:/PyCharm/PycharmProjects/AJ_Recognition/data_prepare/pic/inputdata/'
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    # 通过改变参数train or val,进而验证训练集或测试集
    img = get_one_image(val) 
    evaluate_one_image(img)


测试结果

提示:需要将显示图片的窗口关闭后才会显示结果。
在这里插入图片描述

在这里插入图片描述


全部代码

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import model
from batch import get_files


# 获取一张图片
# 输入参数:train,训练图片的路径
# 返回参数:image,从训练图片中随机抽取一张图片
def get_one_image(train):
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]  # 随机选择测试的图片

    img = Image.open(img_dir)
    plt.imshow(img)
    # 如果不加plt.show(img)则无法显示图片,只显示结果
    # plt.imshow()函数负责对图像进行处理,并显示其格式,而plt.show()则是将plt.imshow()处理后的函数显示出来。
    plt.show(img)
    # 由于图片在预处理阶段以及resize,因此该命令可略
    imag = img.resize([64, 64])  
    image = np.array(imag)
    return image


# 测试图片
def evaluate_one_image(image_array):
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 4

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = model.inference(image, BATCH_SIZE, N_CLASSES)

        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[64, 64, 3])

        logs_train_dir = 'D:/PyCharm/PycharmProjects/AJ_Recognition/train_log'

        saver = tf.train.Saver()

        with tf.Session() as sess:

            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')

            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('This is a Air Jordan I with possibility %.6f' % prediction[:, 0])
            elif max_index == 1:
                print('This is a Air Jordan IV with possibility %.6f' % prediction[:, 1])
            elif max_index == 2:
                print('This is a Air Jordan XI with possibility %.6f' % prediction[:, 2])
            else:
                print('This is a Air Jordan XII with possibility %.6f' % prediction[:, 3])



if __name__ == '__main__':
    train_dir = 'D:/PyCharm/PycharmProjects/AJ_Recognition/data_prepare/pic/inputdata/'
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    # 通过改变参数train or val,进而验证训练集或测试集
    img = get_one_image(val)  
    evaluate_one_image(img)


项目代码

GitHub地址:https://github.com/WellTung666/Tensorflow/tree/master/AJ_Recognition


参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值