tensorflow笔记 tfrecord创建及读取

之前很少仔细看tf的一些基础api,只要能跑通就过了,最近打算花时间把部分基础api整理一下,方便以后使用。


简介

tfrecord是tensorflow训练模型时比较常用的处理大量数据的格式。简单来说,一种二进制数据储存格式,比一次性读取csv或jpg数据要更快,且占用更小的内存。

tfrecord

理论上tfrecord可以保存任意格式的数据。官方给出可以储存的数据格式有三种,FloatList,Int64List,BytesList。储存的tfrecord文件由一个个Example组成,Example是 protocolbuf 协议下的消息体。每一个 Example 包含了一系列的 feature 属性。每一个 feature 包含了一个 key和对应的一个或多个value 。example的具体格式后面会给出示例。

生成tfrecord文件

以一个简单的分类问题数据集为例,feature是一个1x5的向量,label取值为0或1

import numpy as np
import tensorflow as tf

#构建一个简单的分类问题数据集,feature为一个1x5的随机向量,label取值为0或1

#生成10个随机样本,其中一半样本label为0,另一半为1
n = 10
size = (n, 5)

x_data = np.random.randint(0, 10, size=size)
y1_data = np.ones((n//2, 1), int)
y2_data = np.zeros((n//2, 1), int)
y_data = np.vstack((y1_data, y2_data))
np.random.shuffle(y_data)
xy_data = np.hstack((x_data,y_data))
#print(xy_data)
'''
[[2 0 0 5 8 1]
 [8 3 7 5 1 1]
 [3 5 7 8 7 1]
 [5 2 7 9 9 0]
 [0 1 0 3 0 0]
 [0 3 4 2 5 0]
 [4 8 8 3 8 1]
 [3 5 2 7 7 0]
 [0 4 7 7 3 1]
 [5 0 2 4 9 0]]
'''

#储存为tfrecord格式,文件名以.record为后缀
tfrecord_path = 'data.record'
writer = tf.python_io.TFRecordWriter(tfrecord_path)
for i in range(n):
	#Features要求输入格式为list,所以读入的数据需要先转化为list
    sample = x_data[i] 
    label = int(y_data[i])
    example = tf.train.Example(features=tf.train.Features(feature={
   
        'sample':
            tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
        'label':
            tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))
    writer.write(example.SerializeToString())
    #print(example)
    #print(example.SerializeToString())
writer.close()

'''
example格式:
features {
  feature {
    key: "label"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: &
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值