搭建简单图片分类的卷积神经网络(三)-- 模型的测试和运用

两个功能都在同一个文件中

一、新建Disimage.py文件

import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN

classes = []
n_classes = 0

#获取一张图片
def get_one_image(train):
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]  # 随机选择测试的图片

    # img_data = Image.open(img_dir)
    imag = Image.open(img_dir)
    imag = imag.resize([64, 64])  # 由于图片在预处理阶段以及resize,因此该命令可略
    image = np.array(imag)
    return image

def evaluate_one_image(image_array,N_CLASSES):
    with tf.Graph().as_default():
        BATCH_SIZE = 1

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)

        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        logs_train_dir = r'E:\PycharmPython\NewCnn\logs'

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')

            prediction = sess.run(logit,feed_dict={x:image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('This is a animales with possibility %.6f' % prediction[:, 0])
            elif max_index == 1:
                print('This is a banded with possibility %.6f' % prediction[:, 1])
            elif max_index == 2:
                print('This is a potholed with possibility %.6f' % prediction[:, 2])
            elif max_index == 3:
                print('This is a writeflowers with possibility %.6f' % prediction[:, 3])
            else:
                print('This is a yellowflowers with possibility %.6f' % prediction[:, 4])
    return max_index

if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  #训练集路径

    for str_classes in os.listdir(train_dir):
        classes.append(str_classes)
        n_classes =n_classes + 1

    train, train_label, val, val_label = get_files(train_dir, 0.3)
    img = get_one_image(val)  # 通过改变参数train or val,进而验证训练集或测试集
    pre = evaluate_one_image(img,n_classes)

上面是对之前已经处理好图片划分好测试集,进行测试的。

二、将代码改成

import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN


classes = []
n_classes = 0

#对预测后图片路径的处理
def prediction_image_path(Classes,dir):
    for index,name in enumerate(Classes):
        prediction_path = dir +'\\' + name   #判断是否有文件夹
        folder = os.path.exists(prediction_path)
        if not folder :
            os.makedirs(prediction_path)  #创建文件夹
            print(prediction_path,'new file')
        else:
            for str_image in os.listdir(prediction_path):
                prediction_image_path = prediction_path + '\\'+str_image
                os.remove(prediction_image_path)   #清空文件夹
            print('There is this flie')

#获取一张图片
def get_one_image(train):
    # n = len(train)
    # ind = np.random.randint(0, n)
    # img_dir = train[ind]  # 随机选择测试的图片

    img_data = Image.open(train)
    imag = Image.open(train).convert('RGB')
    imag = imag.resize([64, 64])  # 由于图片在预处理阶段以及resize,因此该命令可略
    image = np.array(imag)
    return img_data,image


def evaluate_one_image(image_array,N_CLASSES):
    with tf.Graph().as_default():
        BATCH_SIZE = 1

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)

        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        logs_train_dir = r'E:\PycharmPython\NewCnn\logs'

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')

            prediction = sess.run(logit,feed_dict={x:image_array})
            max_index = np.argmax(prediction)
            # if max_index == 0:
            #     print('This is a animales with possibility %.6f' % prediction[:, 0])
            # elif max_index == 1:
            #     print('This is a banded with possibility %.6f' % prediction[:, 1])
            # elif max_index == 2:
            #     print('This is a potholed with possibility %.6f' % prediction[:, 2])
            # elif max_index == 3:
            #     print('This is a writeflowers with possibility %.6f' % prediction[:, 3])
            # else:
            #     print('This is a yellowflowers with possibility %.6f' % prediction[:, 4])
    return max_index
        # print(max_index)

if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  #训练集路径
    image_dir = r'E:\PycharmPython\NewCnn\image'   #待分类图片路径
    prediction_dir = r'E:\PycharmPython\NewCnn\prediction'  #分类结果存储路径
    for str_classes in os.listdir(train_dir):
        classes.append(str_classes)
        n_classes =n_classes + 1

    # #创建分类后图片的存储路径
    # train, train_label, val, val_label = get_files(train_dir, 0.3)
    # img = get_one_image(val)  # 通过改变参数train or val,进而验证训练集或测试集
    # pre = evaluate_one_image(img,n_classes)
    prediction_image_path(classes,prediction_dir)
    #扫描待分类图片,分类之后存储到对应的分类路径
    for image_data in  os.listdir(image_dir):
        image_data_path = image_dir + '\\'+image_data
        orig_img,img = get_one_image(image_data_path)
        pre = evaluate_one_image(img,n_classes)
        for i in range(n_classes):
            if pre == i:
                print(classes[i])
                orig_img.save(prediction_dir +'\\'+ classes[i] +'\\' +str(i) + image_data+ '.jpg')

上面是对image文件中图片进行分类。

连载:https://blog.csdn.net/qq_28821995/article/details/83587032

https://blog.csdn.net/qq_28821995/article/details/83587530

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值