在开始之前,笔者想简单介绍一下tensorflow程序读取文件的三种方式:
(1):供给数据(Feeding):在TensorFlow程序运行的每一步, 让Python代码来供给数据。
(2):从文件读取数据:在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
(3):预加载数据:在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
开始正文啦:
关于tf.train.batch和tf.train.string_input_producer的区别:
(1)输入所需图片的地址,然后放到tf.train.string_input_producer中进行管理,注意tf.train.string_input_producer中只是图片的的地址,不是图片的值。
(2)然后用各种读取器读取地址中的数据(图片,标签),用的是reader=tf.FixedLengthRecordReader(record_bytes=record_bytes),按字节来读的。还有readertf.WholeFileReader()这是直接读完的。
我们已经把数据放入了文件名队列中啦,剩下的就是两部分:一个是读取数据:result.key,value=reader.read(filename_queue);一个是解析数据:record_bytes=tf.decode_raw(value,tf.uint8)
(3)为了增强图像的稳健性,可以对图像进行一系列的操作,比如旋转、亮度、对比度、裁剪等操作,对了,一定要注意标准化,不然准确率上不去呀!!!
(4)此时的图像已经是tensor形式了,然后就是放入tf.train.batch或tf.train.shuffle_batch中进行管理,以便数据提取,这里才引入线程的。
image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,num_threads=num_preprocess_threads,capacity=min_queue_example+3*batch_size)
总结:TensorfFlow将数据读取分为两步骤哦,先读入文件名队列中,以生成tensor,再读入内存队列进行运算。在tf中有三个函数tf.train.string_input_producer(输入为一个string类型的tensor列表),tf,train.slice_input_producer(输入为一个tensor列表),tf.train.input_producer(输入为一个tensor)用于建立文件名队列。
下面的图很形象:
最后笔者要放上代码啦:
def get_files(filename):
class_train = []
label_train = []
for train_class in os.listdir(filename):
for pic in os.listdir(filename+train_class):
class_train.append(filename+train_class+'/'+pic)
label_train.append(train_class)
temp = np.array([class_train,label_train])
temp = temp.transpose()
#shuffle the samples
np.random.shuffle(temp)
#after transpose, images is in dimension 0 and label in dimension 1
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(i) for i in label_list]
#print(label_list)
return image_list,label_list
def get_batches(image,label,resize_w,resize_h,batch_size,capacity):
#convert the list of images and labels to tensor
image = tf.cast(image,tf.string)
label = tf.cast(label,tf.int64)
queue = tf.train.slice_input_producer([image,label])
label = queue[1]
image_c = tf.read_file(queue[0])
image = tf.image.decode_jpeg(image_c,channels = 3)
#resize
image=tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h)
#(x - mean) / adjusted_stddev
image = tf.image.per_image_standardization(image)
image_batch,label_batch = tf.train.batch([image,label],
batch_size = batch_size,
num_threads = 64,
capacity = capacity)
images_batch = tf.cast(image_batch,tf.float32)
labels_batch = tf.reshape(label_batch,[batch_size])
return images_batch,labels_batch
好了,肯定不全面,就见谅啦,拜拜啦,see you again!小可爱们。。。