TFrecord读取 ——数据过大问题
- TFrecord读取基本代码
def get_batch(self, input_file, batch_size, is_train=True):
# tf解析
def _parse_function(example_proto):
name_to_features = {
"img": tf.io.FixedLenFeature([], tf.string),
"float_label": tf.io.FixedLenFeature([3], tf.float32),
"label": tf.io.FixedLenFeature([1], tf.int64)
}
example = tf.io.parse_single_example(example_proto, name_to_features)
# 对图片进行解码
img = tf.decode_raw(example['img'], tf.uint8)
# 其它值可直接提取
float_label = example['float_label']
label = example['label']
return img,float_label,label
dataset = tf.data.TFRecordDataset(input_file)
if is_train:
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.map(_parse_function)
dataset = dataset.batch(batch_size)
iterator = tf.data.make_initializable_iterator(dataset)
batch = iterator.get_next()
return iterator, batch
with tf.Session() as sess:
# 在sess里面初始化
iterator, batch = get_batch(train_file, batch_size)
sess.run(iterator.initializer)
# 运行每个batch
example = sess.run(batch)
-
TFrecord读取数据时,为了保证获取数据的随机性,需要下面这一行代码。
dataset = dataset.shuffle(buffer_size=10000)
buffer_size的大小代表有多少元素放入了缓存中,并在训练时随机从缓存的元素中取出用于训练。为了保证抽取数据的随机性,应该设置buffer_size=训练元素的数量。但当我们的训练集很大的时候,会占用很大的内存,甚至会内存溢出。
-
解决方案:
-
首先根据样本总数量,制作生成多个小tfrecord文件。
# 计算tfrecord数量 sample_nums = 0 for record in tf.python_io.tf_record_iterator(input_file): sample_nums += 1
-
将tfrecords文件先随机排列,使用tensorflow的interleave函数读取多个tfrecords文件。
# tf_list:所有tfrecord的路径列表 files = tf.data.Dataset.list_files(tf_list, shuffle=True) # cycle_length:interleave同时读取的文件数目 dataset = files.interleave(map_func=tf.data.TFRecordDataset,cycle_length=1)
-
设置shuffle的buffer_size=小的tfrecords文件的元素数目,buffer_size可设置为1000。
shuffle_buffer_size = 1000 dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)