tf.data.Dataset读取TFRecord文件进行训练

前言

之前跑了一个小的CNN,因为训练数据量也不大,自己的电脑显卡2g,直接将数据读到内存中再取batch来训练还能直接跑;后来数据集大了一些,再用这种方法就不行了,内存不足。没办法,找了一些TFRecord和Dataset的资料把程序改了改,这样就可以边取数据边训练,总算能跑了,以后再也不用怕训练数据多啦。

数据

我的训练数据是一些元素为float类型的矩阵,用numpy生成一些来代表吧:

data = np.random.rand(5,10,10)

这样就生成了shape为(5,10,10),元素类型为float的矩阵data了。意义为5个大小为(10,10)的矩阵。

生成TFRecord文件

直接上代码:

import numpy as np
import tensorflow as tf

data = np.random.rand(5,10,10)

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 将数据转化为tf.train.Example格式。
def _make_example(data_raw):
    data_raw = data.tostring()# 先将元素为float类型的numpy数组转换为string
    example = tf.train.Example(features=tf.train.Features(feature={
        'data_raw': _bytes_feature(data_raw)
    }))
    return example

num_examples = data.shape[0]

# 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("output.tfrecords") as writer:
    for index in range(num_examples):
        example = _make_example(data[index])
        writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。")

运行完后就会在当前目录下生成output.tfrecords的文件了,这就是我们的数据。虽然感觉不好理解,其实也就这几句代码就完成了TFRecord数据的生成。
因为TFRecord只能存储BytesList、FloatList、Int64List,所以矩阵不能直接存储,可以将矩阵转换成字符串来存储,如以上程序data_raw = data.tostring()所示。也可以将矩阵flat成一维即FloatList来存储,读出时再reshape转换为之前的格式,按理可行,但我没试。

使用Dataset恢复数据

def parser(record1):
	features = tf.parse_single_example(
    record1,
    features={
        'data_raw':tf.FixedLenFeature([],tf.string),
    })
	data_raw = features['data_raw']
	return data_raw

batch_size = 2
dataset = tf.data.TFRecordDataset("output.tfrecords")
dataset = dataset.map(parser)
dataset = dataset.batch(batch_size)
NUM_EPOCHS = 2
dataset = dataset.repeat(NUM_EPOCHS)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    x_batch = []
    sess.run((tf.global_variables_initializer(),
			tf.local_variables_initializer()))
    sess.run(iterator.initializer)
    while True:
        try:
            x_train = sess.run(next_element) # 取出设定的
            for i in range(x_train.shape[0]):
                x_batch.append(list(np.fromstring(x_train[i],np.float32)))
            print(x_batch)
        except tf.errors.OutOfRangeError:
            break

x_batch就是恢复出来的数据了,是list列表,根据需要可以转换成其他的形式。
parser解析函数的作用是把数据从TFRecord文件中解析出来。
map(parser)表示对数据集中的每一条数据调用parser方法。
dataset.batch函数是设置一次取出的batch的大小。
dataset.repeat函数表示取全部数据几个epoch,比如我这里设置为2,那么在取完两次全部数据后,就会抛出OutOfRangeError异常,可以以此来停止程序循环。
dataset.shuffle函数将数据集打乱顺序,这里没有展示。
每运行一次sess.run(next_element)就会从TFRecord文件中取出设定的batch大小的数据。取出来的数据的格式是字符串,通过np.fromstring()函数转换成真正的数据。这里要注意转换成的float类型,你原来的数据是float32你转换时就要转换为float32,你如果转换成float64,那肯定数据转换出来是错的。

未理解的地方

取出一个batch的数据之后,如果直接整体操作:

x_train = np.fromstring(x_train,np.float32)

转换出来的数据是不对的,必须要一条一条的转换出来,然后再组合起来才行。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值