前言
之前跑了一个小的CNN,因为训练数据量也不大,自己的电脑显卡2g,直接将数据读到内存中再取batch来训练还能直接跑;后来数据集大了一些,再用这种方法就不行了,内存不足。没办法,找了一些TFRecord和Dataset的资料把程序改了改,这样就可以边取数据边训练,总算能跑了,以后再也不用怕训练数据多啦。
数据
我的训练数据是一些元素为float类型的矩阵,用numpy生成一些来代表吧:
data = np.random.rand(5,10,10)
这样就生成了shape为(5,10,10),元素类型为float的矩阵data了。意义为5个大小为(10,10)的矩阵。
生成TFRecord文件
直接上代码:
import numpy as np
import tensorflow as tf
data = np.random.rand(5,10,10)
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 将数据转化为tf.train.Example格式。
def _make_example(data_raw):
data_raw = data.tostring()# 先将元素为float类型的numpy数组转换为string
example = tf.train.Example(features=tf.train.Features(feature={
'data_raw': _bytes_feature(data_raw)
}))
return example
num_examples = data.shape[0]
# 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("output.tfrecords") as writer:
for index in range(num_examples):
example = _make_example(data[index])
writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。")
运行完后就会在当前目录下生成output.tfrecords的文件了,这就是我们的数据。虽然感觉不好理解,其实也就这几句代码就完成了TFRecord数据的生成。
因为TFRecord只能存储BytesList、FloatList、Int64List,所以矩阵不能直接存储,可以将矩阵转换成字符串来存储,如以上程序data_raw = data.tostring()所示。也可以将矩阵flat成一维即FloatList来存储,读出时再reshape转换为之前的格式,按理可行,但我没试。
使用Dataset恢复数据
def parser(record1):
features = tf.parse_single_example(
record1,
features={
'data_raw':tf.FixedLenFeature([],tf.string),
})
data_raw = features['data_raw']
return data_raw
batch_size = 2
dataset = tf.data.TFRecordDataset("output.tfrecords")
dataset = dataset.map(parser)
dataset = dataset.batch(batch_size)
NUM_EPOCHS = 2
dataset = dataset.repeat(NUM_EPOCHS)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
x_batch = []
sess.run((tf.global_variables_initializer(),
tf.local_variables_initializer()))
sess.run(iterator.initializer)
while True:
try:
x_train = sess.run(next_element) # 取出设定的
for i in range(x_train.shape[0]):
x_batch.append(list(np.fromstring(x_train[i],np.float32)))
print(x_batch)
except tf.errors.OutOfRangeError:
break
x_batch就是恢复出来的数据了,是list列表,根据需要可以转换成其他的形式。
parser解析函数的作用是把数据从TFRecord文件中解析出来。
map(parser)表示对数据集中的每一条数据调用parser方法。
dataset.batch函数是设置一次取出的batch的大小。
dataset.repeat函数表示取全部数据几个epoch,比如我这里设置为2,那么在取完两次全部数据后,就会抛出OutOfRangeError异常,可以以此来停止程序循环。
dataset.shuffle函数将数据集打乱顺序,这里没有展示。
每运行一次sess.run(next_element)就会从TFRecord文件中取出设定的batch大小的数据。取出来的数据的格式是字符串,通过np.fromstring()函数转换成真正的数据。这里要注意转换成的float类型,你原来的数据是float32你转换时就要转换为float32,你如果转换成float64,那肯定数据转换出来是错的。
未理解的地方
取出一个batch的数据之后,如果直接整体操作:
x_train = np.fromstring(x_train,np.float32)
转换出来的数据是不对的,必须要一条一条的转换出来,然后再组合起来才行。