两个功能都在同一个文件中
一、新建Disimage.py文件
import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN
# Class (sub-folder) names found under the training directory; filled in by
# the __main__ section below.
classes = list()
# Number of entries collected in `classes`; passed to the network as its
# output size.
n_classes = 0
#获取一张图片
def get_one_image(train):
n = len(train)
ind = np.random.randint(0, n)
img_dir = train[ind] # 随机选择测试的图片
# img_data = Image.open(img_dir)
imag = Image.open(img_dir)
imag = imag.resize([64, 64]) # 由于图片在预处理阶段以及resize,因此该命令可略
image = np.array(imag)
return image
def evaluate_one_image(image_array, N_CLASSES,
                       logs_train_dir=r'E:\PycharmPython\NewCnn\logs',
                       class_names=('animales', 'banded', 'potholed',
                                    'writeflowers', 'yellowflowers')):
    """Classify one image with the latest checkpoint and print the result.

    Builds a fresh graph, standardizes the image, runs ``CNN.inference``
    followed by softmax, and restores the newest checkpoint found in
    ``logs_train_dir``.

    Args:
        image_array: image data with 64*64*3 elements (it is reshaped to
            ``[1, 64, 64, 3]``).
        N_CLASSES: number of network outputs.
        logs_train_dir: directory holding the training checkpoints.
        class_names: label printed for each output index; indices without
            a name are printed by number.

    Returns:
        Index of the most probable class.

    Raises:
        FileNotFoundError: if no checkpoint exists in ``logs_train_dir``.
            Without one the variables would be uninitialized and the run
            would fail anyway.
    """
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])
        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                raise FileNotFoundError(
                    'No checkpoint found in %s' % logs_train_dir)
            # basename copes with both '/' and Windows separators.
            global_step = os.path.basename(
                ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)
            # The graph is fed from the constant image, so no feed_dict is needed.
            prediction = sess.run(logit)
    max_index = int(np.argmax(prediction))
    probability = prediction[0, max_index]  # batch of one image
    if max_index < len(class_names):
        print('This is a %s with possibility %.6f'
              % (class_names[max_index], probability))
    else:
        print('This is class #%d with possibility %.6f'
              % (max_index, probability))
    return max_index
if __name__ == '__main__':
    # Training-set root; every entry directly under it counts as one class.
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'
    entries = os.listdir(train_dir)
    classes.extend(entries)
    n_classes += len(entries)
    # 0.3 is presumably the validation fraction; confirm in GetCnnData.get_files.
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    # Pass `train` instead of `val` to check an image from the other split.
    img = get_one_image(val)
    pre = evaluate_one_image(img, n_classes)
上面的代码用于对之前已经预处理并划分好的测试集图片进行测试。
二、将代码修改为如下形式：
import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN
# Class (sub-folder) names found under the training directory; filled in by
# the __main__ section below.
classes = list()
# Number of entries collected in `classes`; passed to the network as its
# output size.
n_classes = 0
#对预测后图片路径的处理
def prediction_image_path(Classes,dir):
for index,name in enumerate(Classes):
prediction_path = dir +'\\' + name #判断是否有文件夹
folder = os.path.exists(prediction_path)
if not folder :
os.makedirs(prediction_path) #创建文件夹
print(prediction_path,'new file')
else:
for str_image in os.listdir(prediction_path):
prediction_image_path = prediction_path + '\\'+str_image
os.remove(prediction_image_path) #清空文件夹
print('There is this flie')
#获取一张图片
def get_one_image(train):
# n = len(train)
# ind = np.random.randint(0, n)
# img_dir = train[ind] # 随机选择测试的图片
img_data = Image.open(train)
imag = Image.open(train).convert('RGB')
imag = imag.resize([64, 64]) # 由于图片在预处理阶段以及resize,因此该命令可略
image = np.array(imag)
return img_data,image
def evaluate_one_image(image_array, N_CLASSES,
                       logs_train_dir=r'E:\PycharmPython\NewCnn\logs'):
    """Classify one image with the latest checkpoint and return the class index.

    Builds a fresh graph, standardizes the image, runs ``CNN.inference``
    followed by softmax, and restores the newest checkpoint found in
    ``logs_train_dir``. A new graph is built and the checkpoint restored
    on every call, so classifying many images this way is slow.

    Args:
        image_array: image data with 64*64*3 elements (it is reshaped to
            ``[1, 64, 64, 3]``).
        N_CLASSES: number of network outputs.
        logs_train_dir: directory holding the training checkpoints.

    Returns:
        Index of the most probable class.

    Raises:
        FileNotFoundError: if no checkpoint exists in ``logs_train_dir``.
            Without one the variables would be uninitialized and the run
            would fail anyway.
    """
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])
        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                raise FileNotFoundError(
                    'No checkpoint found in %s' % logs_train_dir)
            # basename copes with both '/' and Windows separators.
            global_step = os.path.basename(
                ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)
            # The graph is fed from the constant image, so no feed_dict is needed.
            prediction = sess.run(logit)
    return int(np.argmax(prediction))
if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  # training set root
    image_dir = r'E:\PycharmPython\NewCnn\image'  # images waiting to be classified
    prediction_dir = r'E:\PycharmPython\NewCnn\prediction'  # classification output
    # Every entry under the training root counts as one class.
    entries = os.listdir(train_dir)
    classes.extend(entries)
    n_classes += len(entries)
    # Create (or empty) one output folder per class.
    prediction_image_path(classes, prediction_dir)
    # Classify every image file and store a copy in its class folder.
    for image_data in os.listdir(image_dir):
        image_data_path = os.path.join(image_dir, image_data)
        if not os.path.isfile(image_data_path):
            continue  # sub-folders cannot be opened as images
        orig_img, img = get_one_image(image_data_path)
        pre = evaluate_one_image(img, n_classes)
        if 0 <= pre < n_classes:
            print(classes[pre])
            # The copy is written with a .jpg name. JPEG cannot hold alpha or
            # palette images, so convert those to RGB first.
            if orig_img.mode not in ('L', 'RGB', 'CMYK'):
                orig_img = orig_img.convert('RGB')
            orig_img.save(os.path.join(prediction_dir, classes[pre],
                                       str(pre) + image_data + '.jpg'))
上面的代码用于对 image 文件夹中的图片进行分类。
连载:https://blog.csdn.net/qq_28821995/article/details/83587032