TensorFlow2:TFRecord存储、读取矩阵

TFRecord的优势在于:读取训练集、测试集速度快
TFRecord的层次结构可以表示若干tf.train.example对象,每个tf.train.example对象又由若干个tf.train.Feature字典组成,具体表示如下,其中,tf.train.Feature字典对应训练集、标签

#dataset.tfrecords
[
 {#example_1
  'feature_1':tf.train.Feature,
  ...
  'feature_n':tf.train.Feature,
 },
 ...
 {#example_M
  'feature_1':tf.train.Feature,
  ...
  'feature_n':tf.train.Feature, 
 }
]

训练集、测试集一般为float形式,本文基于float形式的矩阵存储、读取,其中矩阵大小为:样本数目X样本维度,样本数目X标签维度。

TFRecord存储

步骤如下:
(1)读取矩阵
(2)创建tf.train.example对象、tf.train.Feature字典
(3)写入TFRecord文件
代码说明
1、此代码读取csv文件,然后将其转换为TFRecord文件
2、zip()为python内置函数,如果a=[1,3,5],b=[2,4,6],则zip(a,b)=[(1,2),(3,4),(5,6)]
3、train_x样本,train_y标签

#导入训练集
x_train = np.loadtxt(open("radar2/train_x.csv",'rb'),delimiter=",",skiprows=0) 
y_train = np.loadtxt(open("radar2/train_y.csv",'rb'),delimiter=",",skiprows=0) 
#导入测试集
x_test = np.loadtxt(open("radar2/test_x.csv",'rb'),delimiter=",",skiprows=0)
y_test = np.loadtxt(open("radar2/test_y.csv",'rb'),delimiter=",",skiprows=0) 

with tf.io.TFRecordWriter("radar_train.tfrecords") as writer:
    for xtrain,ytrain in zip(x_train,y_train):
        feature = {
            'xtrain':tf.train.Feature(float_list = tf.train.FloatList(value=xtrain)),
            'ytrain':tf.train.Feature(float_list = tf.train.FloatList(value=ytrain)),
            }
        #创建example对象
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        #写入TFRecord
        writer.write(example.SerializeToString())


with tf.io.TFRecordWriter("radar_test.tfrecords") as writer:
    for xtest,ytest in zip(x_test,y_test):
        feature = {
            'xtest':tf.train.Feature(float_list = tf.train.FloatList(value=xtest)),
            'ytest':tf.train.Feature(float_list = tf.train.FloatList(value=ytest))
            }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

TFRecord读取

步骤如下:
(1)创建tf.train.Feature字典,必须和存储时建的同名称
(2)通过tf.data.TFRecordDataset读入原始的TFRecord文件,获得一个tf.data.Dataset数据集对象,后续可以使用keras等框架训练测试
(3)通过Dataset.map执行tf.io.parse_single_example函数将TFRecord文件还原为矩阵形式
代码说明
1、Dataset.map(f):对数据集中的每个元素应用函数f,得到一个新数据集
2、tf.io.parse_single_example将TFRecord文件还原
3、tf.data.TFRecordDataset载入TFRecord文件
4、‘xtrain’:tf.io.FixedLenFeature([300],tf.float32)‘ytrain’:tf.io.FixedLenFeature([3],tf.float32) 中的300和3分别为样本维度和标签维度

feature = {
            'xtrain':tf.io.FixedLenFeature([300],tf.float32),
            'ytrain':tf.io.FixedLenFeature([3],tf.float32),
            }

def parser_example_(x):
    x = tf.io.parse_single_example(x,feature)
    x['xtrain']=tf.reshape(x['xtrain'],[30,10])
    return x['xtrain'],x['ytrain']


   
feature1 = {
            'xtest':tf.io.FixedLenFeature([300],tf.float32),
            'ytest':tf.io.FixedLenFeature([3],tf.float32),
            }

def parser_example_1(x):
    x = tf.io.parse_single_example(x,feature1)
    x['xtest']=tf.reshape(x['xtest'],[30,10])
    return x['xtest'],x['ytest']

dataset = tf.data.TFRecordDataset("radar_train.tfrecords").map(parser_example_)
data_set = tf.data.TFRecordDataset("radar_test.tfrecords").map(parser_example_1)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值