Tensorflow中CSV文件数据读取主要步骤:
1、找到文件 构造文件列表
2、构建一个文件队列
3、构建文件阅读器 读取队列内容(一行)
4、文件解码
5、批处理读取大量数据
代码实现如下:
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
tf = tf.compat.v1
def read_csvfile(filelist):
'''
读取CSV文件
:param file_list:文件名列表
:return:
'''
# 1构造文件队列
filequeue = tf.train.string_input_producer(filelist)
# 2构造csv阅读器读取
reader = tf.TextLineReader()
key,value = reader.read(filequeue)
# print(key,value)
# 3解码文件
records = [["None"],["None"]]
example,label = tf.decode_csv(value,record_defaults=records) #record_defaults指定每一个样本每一列的类型
# print(example,label)
# 4批处理读取多个数据
example_batch, label_batch = tf.train.batch([example, label], batch_size=10, num_threads=2, capacity=9)
# print(example_batch, label_batch)
return example_batch,label_batch
if __name__ == "__main__":
# 0获取文件 构建文件列表
# 获取文件名
filenames = os.listdir("../data/csvdata/")
# print(filenames)
#拼接路径形成完整文件名称
filelist = [os.path.join("../data/csvdata",filename) for filename in filenames]
# print(filelist)
example_batch,label_batch = read_csvfile(filelist)
# 开启会话,运行结果
with tf.Session() as sess:
# 定义一个线程协调器
coord = tf.train.Coordinator()
# 开启读取文件的线程
filereadthreads = tf.train.start_queue_runners(sess,coord=coord)
# 输出读取的内容
print(sess.run([example_batch,label_batch]))
# 回收子线程
coord.request_stop()
coord.join(filereadthreads)