import tensorflow as tf
import numpy as np
import random
import os
import math
from matplotlib import pyplot as plt
def get_files(file_dir, ratio=0.3):
    """Build shuffled train/validation lists of image paths and labels.

    The directory ``file_dir`` must contain one sub-directory per flower
    class: ``daisy``, ``dandelion``, ``roses``, ``sunflowers`` and
    ``tulips``. Every file found in a class directory is labelled with the
    index of that class (0 to 4, in the order listed above).

    :param file_dir: root directory holding the five class sub-directories.
    :param ratio: fraction of all samples reserved for validation
        (rounded up to a whole number of samples). Must be within [0, 1].
    :return: ``(tra_images, tra_labels, val_images, val_labels)`` where the
        image lists hold file paths and the label lists hold ``int`` labels.
    :raises ValueError: if ``ratio`` is outside [0, 1].
    :raises OSError: if a class sub-directory cannot be listed.
    """
    if not 0 <= ratio <= 1:
        raise ValueError("ratio must be within [0, 1], got %r" % (ratio,))

    # The position of a name in this tuple is the integer label of the class.
    class_names = ("daisy", "dandelion", "roses", "sunflowers", "tulips")

    # Step 1: collect every file path together with its class label.
    image_paths = []
    labels = []
    for class_label, class_name in enumerate(class_names):
        class_dir = os.path.join(file_dir, class_name)
        for file_name in os.listdir(class_dir):
            image_paths.append(os.path.join(class_dir, file_name))
            labels.append(class_label)

    # Step 2: shuffle paths and labels with one shared permutation so that
    # each path keeps its own label.
    order = np.random.permutation(len(image_paths))
    image_list = [image_paths[idx] for idx in order]
    label_list = [labels[idx] for idx in order]

    # Step 3: split into a training part and a validation part.
    n_sample = len(label_list)
    n_val = int(math.ceil(n_sample * ratio))
    n_train = n_sample - n_val

    tra_images = image_list[:n_train]
    tra_labels = label_list[:n_train]
    # Slice to the end of the list: a "-1" upper bound would silently drop
    # the last sample from the validation set.
    val_images = image_list[n_train:]
    val_labels = label_list[n_train:]
    return tra_images, tra_labels, val_images, val_labels
def get_batch(image, label, image_W, image_H, channel, batch_size, capacity):
    """Build a queue-based TensorFlow pipeline yielding image/label batches.

    :param image: list of JPEG file paths.
    :param label: list of integer labels, parallel to ``image``.
    :param image_W: first element of the resize target size.
    :param image_H: second element of the resize target size.
    :param channel: number of colour channels to decode each image into.
    :param batch_size: number of samples per batch.
    :param capacity: maximum number of elements held by the batching queue.
    :return: ``(image_batch, label_batch)`` tensors; ``image_batch`` is
        float32 with shape ``[batch_size, image_W, image_H, channel]`` and
        ``label_batch`` is int32 with shape ``[batch_size]``.
    """
    # Step 1: convert the Python lists to tensors and build an input queue
    # that yields one (path, label) pair at a time.
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])

    # Step 2: decode the image. Only one encoding may be used per pipeline
    # (JPEG here). Requesting ``channel`` channels explicitly keeps the
    # decoded tensor consistent with the static shape declared below, even
    # for files stored with a different channel count (e.g. grayscale).
    images_value = tf.image.decode_jpeg(image_contents, channels=channel)

    # Step 3: preprocessing - resize to a fixed size, then standardize.
    # NOTE(review): TensorFlow interprets ``size`` as [new_height, new_width],
    # so image_W is effectively used as the height. The order is kept as-is
    # for backward compatibility; it only matters for non-square targets.
    image = tf.image.resize_images(images_value, size=[image_W, image_H])
    image.set_shape(shape=[image_W, image_H, channel])
    image = tf.image.per_image_standardization(image)

    # Step 4: assemble batches.
    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=capacity,
    )
    # Make the label batch an explicit rank-1 tensor of length batch_size.
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch
if __name__ == "__main__":
    # Smoke test: build the input pipeline and display a few training images.
    BATCH_SIZE = 2
    CAPACITY = 256
    IMG_W = 208
    IMG_H = 208
    IMG_CHANNEL = 3  # get_batch requires the channel count explicitly
    N_BATCHES = 2    # number of batches to fetch and display

    # Root directory of the flower data set.
    mypath = "/home/sunxiaoming/PycharmProjects/data/flower_photos"

    # get_files returns four lists (train/validation images and labels);
    # only the training split is visualised here.
    tra_images, tra_labels, val_images, val_labels = get_files(mypath)
    print(len(tra_images))
    print(len(tra_labels))

    image_batch, label_batch = get_batch(
        tra_images, tra_labels, IMG_W, IMG_H, IMG_CHANNEL, BATCH_SIZE, CAPACITY
    )
    print(image_batch)

    with tf.Session() as sess:
        # Start the queue-runner threads under a coordinator.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(N_BATCHES):
                images, labels = sess.run([image_batch, label_batch])
                # Iterate over the actual batch size rather than a literal 2.
                for j in range(BATCH_SIZE):
                    plt.imshow(images[j, :, :, :])
                    plt.show()
        finally:
            # Always stop and join the threads, even if the loop fails.
            coord.request_stop()
            coord.join(threads)
#with tf.Session() as sess:
# 开启线程
# 线程协调元
#coord = tf.train.Coordinator()
#threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#i=0
#while not coord.should_stop() and i < 2:
#lable, image = sess.run([image_batch,label_batch])
#print(type(image))
#"""
#for j in np.arange(BATCH_SIZE):
# print('label: %d' % lable[j])
#plt.imshow(image[j, :, :, :])
#plt.show()
#i += 1
#"""
# 回收线程
#coord.request_stop()
#coord.join(threads)