tensorflow-slim非量化分类图片模型测试

该博客演示了如何使用TensorFlow加载预训练的ResNet_v2_152模型,对图片进行分类。通过配置GPU选项,加载指定路径的检查点文件,对测试图片进行解码、预处理,并获取预测标签。最后,根据预测标签输出对应的类别名称。
摘要由CSDN通过智能技术生成
import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
slim = tf.contrib.slim

def main(_):
    checkpoint_path = './log/'
    test_path = './test/5_4.jpg'
    num_classes = 6
    model_name = 'resnet_v2_152'
    preprocessing_name = None
    test_image_size = None
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()
        # Select the model
        network_fn = nets_factory.get_network_fn(
            model_name,
            num_classes,
            is_training=False)
        # Select the preprocessing function
        preprocessing_name = preprocessing_name or model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = test_image_size or network_fn.default_image_size

        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        else:
            checkpoint_path = checkpoint_path

        tf.Graph().as_default()
        with tf.Session() as sess:
            image = open(test_path, 'rb').read()
            image = tf.image.decode_jpeg(image, channels=3)
            processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
            processed_images = tf.expand_dims(processed_image, 0)
            logits, _ = network_fn(processed_images)

            predictions = tf.argmax(logits, 1)
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            np_image, network_input, predictions = sess.run([image, processed_image, predictions])
            a = logits.eval()

            if predictions[0] == 0:
                label = "chunbai"
            if predictions[0] == 1:
                label = "chunhei"
            if predictions[0] == 2:
                label = "hongse"
            if predictions[0] == 3:
                label = "huierxian or huishiban"
            if predictions[0] == 4:
                label = "huiyudian or huipendian"
            if predictions[0] == 5:
                label = "qilinhua"
            print('{} {}'.format(test_path, label))
if __name__ == '__main__':
    tf.app.run()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王二小、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值