Tensorflow文件读取
文件读取流程
如果读取ABC三个文件 :
1.将文件和文件路径乱序或者顺序的放入队列当中,构造一个文件队列。
- 构造文件阅读器,读取队列内容,应为文件格式繁多,默认读取一个样本比如图片文件:按一张一张读取,cv文件:读取一行,二进制文件:指定一个样本的bytes读取。需要多次读取
3.进行解码操作转换(decode),转换出来的也是一个样本。
4.批量处理,比如读取了A样本中的50个样本,然后放入队列,这就是批处理。
主线程要做:取样本数据训练,子线程处理1,2,3,4步骤
1、文件读取API-文件队列构造
tf.train.string_input_producer(string_tensor,shuffle=True)
将输出字符串(例如文件名)输入到管道队列
string_tensor 含有文件名的1阶张量
num_epochs:过几遍数据,默认无限过数据
return:具有输出字符串的队列
2、文件读取API-文件阅读器
根据文件格式,选择对应的文件阅读器
class tf.TextLineReader
阅读文本文件逗号分隔值(CSV)格式,默认按行读取
return:读取器实例
tf.FixedLengthRecordReader(record_bytes)
要读取每个记录是固定数量字节的二进制文件
record_bytes:整型,指定每次读取的字节数
return:读取器实例
tf.TFRecordReader
读取TfRecords文件
有一个共同的读取方法:
read(file_queue):从队列中指定数量内容
返回一个Tensors元组(key文件名字,value默认的内容(行,字节))
3、文件读取API-文件内容解码器
由于从文件中读取的是字符串,需要函数去解析这些字符串到张量
tf.decode_csv(records,record_defaults=None,field_delim = None,name = None)
将CSV转换为张量,与tf.TextLineReader搭配使用
records:tensor型字符串,每个字符串是csv中的记录行
field_delim:默认分割符”,”
record_defaults:参数决定了所得张量的类型,并设置一个值在输入字符串中缺少使用默认值
tf.decode_raw(bytes,out_type,little_endian = None,name = None)
将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用,二进制读取为uint8格式
CSV文件读取
1.先找到文件,构造一个列表
2.构造文件队列
3.构造阅读器,读取队列内容
4.解码内容
5.批处理(多个样本)
开启线程操作:
tf.train.start_queue_runners(sess=None,coord=None)
收集所有图中的队列线程,并启动线程
sess:所在的会话中
coord:线程协调器
return:返回所有线程队列
4.管道读端批处理
批处理大小跟队列,数据的数量没有影响,只决定 这批次取多少数据
tf.train.batch(tensors,batch_size,num_threads = 1,capacity = 32,name=None)
读取指定大小(个数)的张量
tensors:可以是包含张量的列表
batch_size:从队列中读取的批处理大小
num_threads:进入队列的线程数
capacity:整数,队列中元素的最大数量
return:tensors
tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue, num_threads=1,)
乱序读取指定大小(个数)的张量
min_after_dequeue:留下队列里的张量个数,能够保持随机打乱
import tensorflow as tf
import os
# 批处理大小跟队列,数据的数量没有影响 只决定 这批次取多少数据
def csvread(filelist):
"""读取csv文件"""
#1.构造文件的队列
file_queue=tf.train.string_input_producer(filelist)
#2.构造阅读器读取队列数据(一行)
reader=tf.TextLineReader()
key,value=reader.read(file_queue)
#3.对每行内容解码
#record_defaults 指定每一个样本的每一列的类型或者指定默认值
records=[["None"],["None"],["float"],["float"]]
first,second,three,four=tf.decode_csv(value,record_defaults=records)
print(first,second,three,four)
#4.想要读取多个数据就需要进行批处理
first_batch,second_batch,three_batch,four_batch=tf.train.batch([first,second,three,four],batch_size=10,num_threads=1,capacity=10)
print(first_batch,second_batch,three_batch,four_batch)
return first_batch,second_batch,three_batch,four_batch
if __name__ == '__main__':
#1.找到文件,放入列表 路径+名字 放入列表
file_name=os.listdir("./csv文件/")
file_list=[os.path.join("./csv文件/",file) for file in file_name]
first_batch,second_batch,three_batch,four_batch=csvread(file_list)
#开启会话运行
with tf.Session() as sess:
#定义一个线程协调器
coord=tf.train.Coordinator()
#开启读取文件的线程
threads=tf.train.start_queue_runners(sess,coord=coord)
#打印读取内容
print(sess.run([first_batch,second_batch,three_batch,four_batch]))
#回收子线程
coord.request_stop()
coord.join(threads)