TensorFlow 批量读取多个 CSV 文件
#!/usr/bin/python
# -*- coding:utf-8 -*-
import tensorflow as tf
import os
def csvfile(fileist, batch_size=9, num_threads=1, capacity=9):
    """Build a TF1 queue pipeline that reads batches of rows from CSV files.

    The given file paths are fed through a filename queue, read line by line
    with ``tf.TextLineReader`` and parsed by ``tf.decode_csv`` into two string
    columns; empty fields fall back to the string ``'None'``. Every row must
    have exactly two comma-separated fields, otherwise ``tf.decode_csv``
    fails at run time.

    Args:
        fileist: list of CSV file paths to read. (The parameter keeps its
            original spelling so existing keyword callers keep working.)
        batch_size: number of rows per returned batch (default 9).
        num_threads: number of threads enqueuing rows into the batch queue.
        capacity: maximum number of rows held in the batch queue.

    Returns:
        A tuple ``(example_batch, label_batch)`` of string tensors of shape
        ``[batch_size]`` holding the first and second CSV column.

    Note:
        The returned tensors only produce data after
        ``tf.train.start_queue_runners`` has been called for the session.
    """
    # Use the argument itself; the original body read the global ``filelist``
    # instead, so it silently ignored whatever the caller passed in.
    file_queue = tf.train.string_input_producer(fileist)
    reader = tf.TextLineReader()
    # reader.read returns (key, value); only the line text is needed.
    _, value = reader.read(file_queue)
    # One default per column; its dtype (string) also fixes the column dtype.
    records = [['None'], ['None']]
    example, label = tf.decode_csv(value, record_defaults=records)
    example_batch, label_batch = tf.train.batch(
        [example, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
    )
    return example_batch, label_batch
if __name__ == '__main__':
    # One directory for both listing and opening the files: the original
    # listed "../data" but joined the names onto an unrelated hard-coded
    # absolute Windows path, so the two could silently disagree.
    data_dir = "../data"
    # Only pick up CSV files; sort so the file order is deterministic
    # (os.listdir returns entries in arbitrary order).
    listname = sorted(
        name for name in os.listdir(data_dir)
        if name.lower().endswith(".csv")
    )
    print(listname)
    filelist = [os.path.join(data_dir, name) for name in listname]
    print(filelist)
    if not filelist:
        # An empty filename queue would only fail later, inside a reader thread.
        raise SystemExit("no CSV files found in %s" % data_dir)

    example_batch, label_batch = csvfile(filelist)
    with tf.Session() as sess:
        # Thread coordinator used to signal and join the reader threads.
        coord = tf.train.Coordinator()
        # Start the queue-runner threads that read the files.
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([example_batch, label_batch]))
        finally:
            # Stop and reclaim the reader threads even if sess.run raises.
            coord.request_stop()
            coord.join(threads)