之前很少仔细看tf的一些基础api,只要能跑通就过了,最近打算花时间把部分基础api整理一下,方便以后使用。
简介
tfrecord是tensorflow训练模型时比较常用的处理大量数据的格式。简单来说,一种二进制数据储存格式,比一次性读取csv或jpg数据要更快,且占用更小的内存。
tfrecord
理论上tfrecord可以保存任意格式的数据。官方给出可以储存的数据格式有三种,FloatList,Int64List,BytesList。储存的tfrecord文件由一个个Example组成,Example是 protocolbuf 协议下的消息体。每一个 Example 包含了一系列的 feature 属性。每一个 feature 包含了一个 key和对应的一个或多个value 。example的具体格式后面会给出示例。
生成tfrecord文件
以一个简单的分类问题数据集为例,feature是一个1x5的向量,label取值为0或1
import numpy as np
import tensorflow as tf
#构建一个简单的分类问题数据集,feature为一个1x5的随机向量,label取值为0或1
#生成10个随机样本,其中一半样本label为0,另一半为1
n = 10
size = (n, 5)
x_data = np.random.randint(0, 10, size=size)
y1_data = np.ones((n//2, 1), int)
y2_data = np.zeros((n//2, 1), int)
y_data = np.vstack((y1_data, y2_data))
np.random.shuffle(y_data)
xy_data = np.hstack((x_data,y_data))
#print(xy_data)
'''
[[2 0 0 5 8 1]
[8 3 7 5 1 1]
[3 5 7 8 7 1]
[5 2 7 9 9 0]
[0 1 0 3 0 0]
[0 3 4 2 5 0]
[4 8 8 3 8 1]
[3 5 2 7 7 0]
[0 4 7 7 3 1]
[5 0 2 4 9 0]]
'''
#储存为tfrecord格式,文件名以.record为后缀
tfrecord_path = 'data.record'
writer = tf.python_io.TFRecordWriter(tfrecord_path)
for i in range(n):
#Features要求输入格式为list,所以读入的数据需要先转化为list
sample = x_data[i]
label = int(y_data[i])
example = tf.train.Example(features=tf.train.Features(feature={
'sample':
tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
'label':
tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
writer.write(example.SerializeToString())
#print(example)
#print(example.SerializeToString())
writer.close()
'''
example格式:
features {
feature {
key: "label"
value {
int64_list {
value: 1
}
}
}
feature {
key: &