tensorflow要想用起来,首先自己得搞定数据输入。官方文档中介绍了几种,1.一次性从内存中读取数据到矩阵中,直接输入;2.从文件中边读边输入,而且已经给设计好了多线程读写模型;3.把网络或者内存中的数据转化为tensorflow的专用格式tfRecord,存文件后再读取。
其中,从文件中边读边输入,官方文档举例是用的CSV格式文件。我在网上找了一份代码,修改了一下,因为他的比较简略,我就补充一下遇到的问题
先贴代码
#coding=utf-8import tensorflow as tf
import numpy as np
defreadMyFileFormat(fileNameQueue):
reader = tf.TextLineReader()
key, value = reader.read(fileNameQueue)
record_defaults = [[1], [1], [1]]
col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults)
features = tf.pack([col1, col2])
label = col3
return features, label
definputPipeLine(fileNames = ["1.csv","2.csv"], batchSize =4, numEpochs = None):
fileNameQueue = tf.train.string_input_producer(fileNames, num_epochs = numEpochs)
example, label = readMyFileFormat(fileNameQueue)
min_after_dequeue =8
capacity = min_after_dequeue +