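TFRecord is TensorFlow's binary record format. Its advantage is that training and test sets stored in it can be read quickly.
A TFRecord file is organized hierarchically. It holds a number of tf.train.Example objects, and each tf.train.Example consists of a dictionary of tf.train.Feature entries, as shown below. Here the tf.train.Feature entries hold the samples and their labels.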
# dataset.tfrecords
[
    {   # example_1
        'feature_1': tf.train.Feature,
        ...
        'feature_n': tf.train.Feature,
    },
    ...
    {   # example_M
        'feature_1': tf.train.Feature,
        ...
        'feature_n': tf.train.Feature,
    }
]
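The hierarchy above can be seen on a single record. This is a minimal sketch with made-up values: one 300-dimensional sample and one 3-dimensional label are placed in a tf.train.Feature dictionary and wrapped in a tf.train.Example. The Example is then serialized to the bytes a .tfrecords file would contain and decoded again. The key names 'xtrain' and 'ytrain' are taken from the code later in this article. The data is illustrative only.

```python
import numpy as np
import tensorflow as tf

# Illustrative values: one 300-dim sample and its 3-dim label (float64, like np.loadtxt output).
x = np.arange(300.0)
y = np.array([0.0, 1.0, 0.0])

# The Feature dictionary: one tf.train.Feature per named field.
feature = {
    "xtrain": tf.train.Feature(float_list=tf.train.FloatList(value=x)),
    "ytrain": tf.train.Feature(float_list=tf.train.FloatList(value=y)),
}

# A tf.train.Example wraps the dictionary; a .tfrecords file is a sequence of serialized Examples.
example = tf.train.Example(features=tf.train.Features(feature=feature))
serialized = example.SerializeToString()

# Decode the bytes back into an Example to see the same structure again.
decoded = tf.train.Example.FromString(serialized)
keys = sorted(decoded.features.feature.keys())
label = list(decoded.features.feature["ytrain"].float_list.value)
print(keys)   # ['xtrain', 'ytrain']
print(label)  # [0.0, 1.0, 0.0]
```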
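Training and test sets are usually floating-point data, so this article stores and reads float matrices. The sample matrix has size (number of samples) × (sample dimension), and the label matrix has size (number of samples) × (label dimension).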
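One detail about "float" is worth checking. tf.train.FloatList holds 32-bit floats, while np.loadtxt returns float64 by default, so values are rounded to float32 when stored. The sketch below uses a made-up two-value row, writes it into an Example, and parses it back with tf.io.parse_single_example. It then compares the result against the float64 original and against a float32 cast of it.

```python
import numpy as np
import tensorflow as tf

row = np.array([0.1, 1.0 / 3.0])  # float64, like a row returned by np.loadtxt

example = tf.train.Example(features=tf.train.Features(feature={
    "x": tf.train.Feature(float_list=tf.train.FloatList(value=row)),
}))
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"x": tf.io.FixedLenFeature([2], tf.float32)},
)
restored = parsed["x"].numpy()

print(restored.dtype)                                    # float32
print(np.array_equal(restored, row))                     # False: 0.1 and 1/3 were rounded
print(np.array_equal(restored, row.astype(np.float32)))  # True
```

For most training data this rounding is harmless. If full float64 precision matters, FloatList is not the right container.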
Writing TFRecord files
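The steps are:
(1) Read the sample and label matrices.
(2) For each sample, create a tf.train.Feature dictionary and wrap it in a tf.train.Example object.
(3) Serialize each Example and write it to the TFRecord file.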
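The three steps can be sketched as one reusable function. This is a sketch under assumptions: the function name `matrices_to_tfrecord` and its parameters are hypothetical and not part of the original code, random matrices stand in for the CSV data, and the file goes to a temporary directory. The function also reshapes 1-D inputs to a single column, so that a label file with one column still yields one list of values per row.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def matrices_to_tfrecord(path, x, y, x_key, y_key):
    """Hypothetical helper: write one tf.train.Example per row of (x, y)."""
    # Step (1): x is (num_samples, sample_dim), y is (num_samples, label_dim);
    # 1-D inputs are turned into a single column.
    x = np.asarray(x, dtype=np.float64).reshape(len(x), -1)
    y = np.asarray(y, dtype=np.float64).reshape(len(y), -1)
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same number of rows")
    count = 0
    with tf.io.TFRecordWriter(path) as writer:
        for x_row, y_row in zip(x, y):
            # Step (2): build the Feature dictionary and wrap it in an Example.
            feature = {
                x_key: tf.train.Feature(float_list=tf.train.FloatList(value=x_row)),
                y_key: tf.train.Feature(float_list=tf.train.FloatList(value=y_row)),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # Step (3): serialize and append to the file.
            writer.write(example.SerializeToString())
            count += 1
    return count


# Random stand-in data: 5 samples with 300 features and 3 label values each.
rng = np.random.default_rng(0)
x_demo = rng.random((5, 300))
y_demo = rng.random((5, 3))
out_path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")
written = matrices_to_tfrecord(out_path, x_demo, y_demo, "xtrain", "ytrain")
print(written)  # 5
```

With a helper like this, the training and test files differ only in their arguments, for example `matrices_to_tfrecord("radar_test.tfrecords", x_test, y_test, "xtest", "ytest")`.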
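Notes on the code:
1. This code reads CSV files (under radar2/) and converts them into TFRecord files.
2. zip() is a Python built-in function: if a=[1,3,5] and b=[2,4,6], then list(zip(a,b)) == [(1,2),(3,4),(5,6)]. In Python 3, zip returns an iterator, which is why list() is needed to see the pairs.
3. train_x.csv holds the samples and train_y.csv holds the labels. test_x.csv and test_y.csv play the same roles for the test set.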
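Note 2 can be checked directly. Zipping two 2-D NumPy arrays pairs them row by row, which is how the writing loop matches each sample with its label. The small arrays here are illustrative stand-ins for the CSV matrices.

```python
import numpy as np

a = [1, 3, 5]
b = [2, 4, 6]
# In Python 3, zip returns a lazy iterator; list() materializes the pairs.
pairs = list(zip(a, b))
print(pairs)  # [(1, 2), (3, 4), (5, 6)]

# Zipping two 2-D arrays iterates over their rows in lockstep: (sample, label) pairs.
x = np.arange(6.0).reshape(3, 2)  # 3 samples, 2 features each
y = np.eye(3)                     # 3 one-hot labels
row_pairs = list(zip(x, y))
for x_row, y_row in row_pairs:
    print(x_row, y_row)
```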
import numpy as np
import tensorflow as tf

# Load the training set (train_x: samples, train_y: labels).
# ndmin=2 keeps a one-column file as a 2-D (rows x 1) matrix instead of a 1-D vector.
x_train = np.loadtxt("radar2/train_x.csv", delimiter=",", ndmin=2)
y_train = np.loadtxt("radar2/train_y.csv", delimiter=",", ndmin=2)
# Load the test set
x_test = np.loadtxt("radar2/test_x.csv", delimiter=",", ndmin=2)
y_test = np.loadtxt("radar2/test_y.csv", delimiter=",", ndmin=2)

with tf.io.TFRecordWriter("radar_train.tfrecords") as writer:
    for xtrain, ytrain in zip(x_train, y_train):
        feature = {
            'xtrain': tf.train.Feature(float_list=tf.train.FloatList(value=xtrain)),
            'ytrain': tf.train.Feature(float_list=tf.train.FloatList(value=ytrain)),
        }
        # Create the Example object
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize it and write it to the TFRecord file
        writer.write(example.SerializeToString())

with tf.io.TFRecordWriter("radar_test.tfrecords") as writer:
    for xtest, ytest in zip(x_test, y_test):
        feature = {
            'xtest': tf.train.Feature(float_list=tf.train.FloatList(value=xtest)),
            'ytest': tf.train.Feature(float_list=tf.train.FloatList(value=ytest)),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
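Before writing the reading code, it can help to check what actually ended up in a file: how many records it holds, and the names and lengths of the features. These are exactly the values the reading step needs. The sketch below writes a tiny file in the same layout as radar_train.tfrecords, using random stand-in data in a temporary directory rather than the real radar files. It then counts the records and decodes the first one without any feature description.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny stand-in file in the radar_train.tfrecords layout (random data, not the real set).
path = os.path.join(tempfile.mkdtemp(), "check.tfrecords")
x_small = np.random.default_rng(1).random((4, 300))
y_small = np.eye(3)[[0, 1, 2, 0]]  # 4 one-hot labels
with tf.io.TFRecordWriter(path) as writer:
    for xr, yr in zip(x_small, y_small):
        feature = {
            "xtrain": tf.train.Feature(float_list=tf.train.FloatList(value=xr)),
            "ytrain": tf.train.Feature(float_list=tf.train.FloatList(value=yr)),
        }
        writer.write(tf.train.Example(
            features=tf.train.Features(feature=feature)).SerializeToString())

# TFRecordDataset yields one serialized Example (a scalar string tensor) per record.
raw = tf.data.TFRecordDataset(path)
num_records = sum(1 for _ in raw)

# Decode the first record to list its feature names and value counts.
first = tf.train.Example.FromString(next(iter(raw)).numpy())
lengths = {name: len(f.float_list.value) for name, f in first.features.feature.items()}
print(num_records)             # 4
print(sorted(lengths.items())) # [('xtrain', 300), ('ytrain', 3)]
```

Run against the real radar_train.tfrecords, the same inspection would show the lengths to put into tf.io.FixedLenFeature when reading.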
Reading TFRecord files
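The steps are:
(1) Create a feature description dictionary that maps each feature name to a tf.io.FixedLenFeature. The names must be the same as those used when writing.
(2) Load the raw TFRecord file with tf.data.TFRecordDataset, which returns a tf.data.Dataset of serialized records.
(3) Use Dataset.map to apply tf.io.parse_single_example to every record, restoring the samples and labels as tensors. The resulting dataset can then be used for training and testing with frameworks such as Keras.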
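The three reading steps can be sketched with a single parser factory. This is a sketch under assumptions: the factory name `make_parser` and its parameters are hypothetical, and a small random stand-in file is written first so the example is self-contained. The feature names, the dimensions 300 and 3, and the 30 × 10 reshape follow the article's own code.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def make_parser(x_key, y_key, x_dim, y_dim, x_shape):
    """Hypothetical helper: build a parse function for one (sample, label) pair."""
    # Step (1): feature description; keys must match the names used when writing.
    description = {
        x_key: tf.io.FixedLenFeature([x_dim], tf.float32),
        y_key: tf.io.FixedLenFeature([y_dim], tf.float32),
    }

    def parse(serialized):
        parsed = tf.io.parse_single_example(serialized, description)
        return tf.reshape(parsed[x_key], x_shape), parsed[y_key]

    return parse


# Random stand-in data in the training layout: 3 samples, 300 features, 3-dim labels.
x_small = np.random.default_rng(3).random((3, 300))
y_small = np.eye(3)
path = os.path.join(tempfile.mkdtemp(), "read_demo.tfrecords")
with tf.io.TFRecordWriter(path) as writer:
    for xr, yr in zip(x_small, y_small):
        feature = {
            "xtrain": tf.train.Feature(float_list=tf.train.FloatList(value=xr)),
            "ytrain": tf.train.Feature(float_list=tf.train.FloatList(value=yr)),
        }
        writer.write(tf.train.Example(
            features=tf.train.Features(feature=feature)).SerializeToString())

# Step (2): dataset of serialized records; step (3): map the parser over it.
train_ds = tf.data.TFRecordDataset(path).map(make_parser("xtrain", "ytrain", 300, 3, [30, 10]))
x0, y0 = next(iter(train_ds))
print(x0.shape, y0.shape)  # (30, 10) (3,)
```

The factory replaces the two near-identical feature/parser pairs for the training and test sets. The test parser would be `make_parser("xtest", "ytest", 300, 3, [30, 10])`.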
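Notes on the code:
1. Dataset.map(f) applies the function f to every element of the dataset and returns a new dataset.
2. tf.io.parse_single_example parses one serialized tf.train.Example back into a dictionary of tensors, as described by the feature dictionary.
3. tf.data.TFRecordDataset loads a TFRecord file as a dataset of serialized records.
4. In 'xtrain': tf.io.FixedLenFeature([300], tf.float32) and 'ytrain': tf.io.FixedLenFeature([3], tf.float32), 300 and 3 are the sample dimension and the label dimension, respectively. The parser then reshapes each 300-dimensional sample into a 30 × 10 matrix, which requires 30 × 10 = 300.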
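The lengths in note 4 must match what was written. The sketch below uses a single made-up label record to show two things. First, a FixedLenFeature of the correct length parses cleanly, while a wrong length raises tf.errors.InvalidArgumentError instead of silently padding or truncating. Second, tf.reshape fills the 30 × 10 matrix in row-major order.

```python
import tensorflow as tf

serialized = tf.train.Example(features=tf.train.Features(feature={
    "ytrain": tf.train.Feature(float_list=tf.train.FloatList(value=[0.0, 1.0, 0.0])),
})).SerializeToString()

# Correct length: 3 matches the label dimension that was written.
ok = tf.io.parse_single_example(serialized, {"ytrain": tf.io.FixedLenFeature([3], tf.float32)})
print(ok["ytrain"].numpy())  # [0. 1. 0.]

# Wrong length: parsing fails rather than padding or truncating.
try:
    tf.io.parse_single_example(serialized, {"ytrain": tf.io.FixedLenFeature([4], tf.float32)})
    mismatch_raised = False
except tf.errors.InvalidArgumentError:
    mismatch_raised = True
print(mismatch_raised)  # True

# Reshaping 300 values to [30, 10] is row-major: row 1 starts at original index 10.
reshaped = tf.reshape(tf.range(300, dtype=tf.float32), [30, 10])
print(reshaped[1, 0].numpy())  # 10.0
```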
import tensorflow as tf

# Feature descriptions for the training set; the keys must match those used when writing
feature = {
    'xtrain': tf.io.FixedLenFeature([300], tf.float32),
    'ytrain': tf.io.FixedLenFeature([3], tf.float32),
}

def parser_example_(x):
    x = tf.io.parse_single_example(x, feature)
    # Restore each 300-dim sample to a 30 x 10 matrix (30 * 10 = 300)
    x['xtrain'] = tf.reshape(x['xtrain'], [30, 10])
    return x['xtrain'], x['ytrain']

# Feature descriptions for the test set
feature1 = {
    'xtest': tf.io.FixedLenFeature([300], tf.float32),
    'ytest': tf.io.FixedLenFeature([3], tf.float32),
}

def parser_example_1(x):
    x = tf.io.parse_single_example(x, feature1)
    x['xtest'] = tf.reshape(x['xtest'], [30, 10])
    return x['xtest'], x['ytest']

# Each element is a (sample, label) pair: a (30, 10) tensor and a (3,) tensor
dataset = tf.data.TFRecordDataset("radar_train.tfrecords").map(parser_example_)
data_set = tf.data.TFRecordDataset("radar_test.tfrecords").map(parser_example_1)
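Once parsed, the dataset can be handed to Keras as (sample, label) pairs. The sketch below assumes one-hot labels and uses a deliberately tiny model (Flatten plus a softmax Dense layer). Both choices are only for illustration: the article does not specify a model, and the data here is a random stand-in written to a temporary file. The pipeline shuffles and batches the parsed dataset before calling model.fit.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Random stand-in data in the training layout: 32 samples x 300 features, one-hot 3-class labels.
rng = np.random.default_rng(2)
path = os.path.join(tempfile.mkdtemp(), "train_demo.tfrecords")
with tf.io.TFRecordWriter(path) as writer:
    for xr, yr in zip(rng.random((32, 300)), np.eye(3)[rng.integers(0, 3, 32)]):
        feature = {
            "xtrain": tf.train.Feature(float_list=tf.train.FloatList(value=xr)),
            "ytrain": tf.train.Feature(float_list=tf.train.FloatList(value=yr)),
        }
        writer.write(tf.train.Example(
            features=tf.train.Features(feature=feature)).SerializeToString())

description = {
    "xtrain": tf.io.FixedLenFeature([300], tf.float32),
    "ytrain": tf.io.FixedLenFeature([3], tf.float32),
}

def parse(serialized):
    parsed = tf.io.parse_single_example(serialized, description)
    return tf.reshape(parsed["xtrain"], [30, 10]), parsed["ytrain"]

# Parse, shuffle, batch, and prefetch before handing the dataset to Keras.
train_ds = (
    tf.data.TFRecordDataset(path)
    .map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(buffer_size=32)
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)

# Illustrative model only: flatten the 30 x 10 sample and predict 3 classes.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(30, 10)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
history = model.fit(train_ds, epochs=2, verbose=0)
print(len(history.history["loss"]))  # 2
```

The test dataset (data_set above) would normally be batched the same way but not shuffled, and then passed to model.evaluate.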