TensorFlow2实战-系列教程7:TFRecords数据源制作1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、TFRecords

在训练过程中,基本都是使用GPU来计算,但是取一个一个batch取数据还是必须要用cpu,这个过程耗费时间也会影响训练时间,制作TFRecords可以有效解决这个问题,此外制作TFRecords数据可以更好的管理存储数据

为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法,tf.Example类是一种将数据表示为{“string”: value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据

通常情况下,tf.Example中可以使用以下几种格式:

  • tf.train.BytesList: 可以使用的类型包括 string和byte
  • tf.train.FloatList: 可以使用的类型包括 float和double
  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64

TFRecords是TensorFlow官方推荐的

2、转化示例

def _bytes_feature(value):
    """Returns a bytes_list from a string/byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Return a float_list form a float/double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Return a int64_list from a bool/enum/int/uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

定义3个函数分别对3种类型的数据进行转换成对应的TensorFlow的数据格式

# tf.train.BytesList
print(_bytes_feature(b'test_string'))
print(_bytes_feature('test_string'.encode('utf8')))

# tf.train.FloatList
print(_float_feature(np.exp(1)))

# tf.train.Int64List
print(_int64_feature(True))
print(_int64_feature(1))

传进几个numpy格式的数据,再调用上面的函数进行转换,再打印:

bytes_list { value: “test_string” }
bytes_list { value: “test_string” }
float_list { value: 2.7182817459106445 }
int64_list { value: 1 }
int64_list { value: 1 }

3、TFRecords制作方法

def serialize_example(feature0, feature1, feature2, feature3):
    """
    创建tf.Example
    """
    # 转换成相应类型
    feature = {
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    }
    #使用tf.train.Example来创建
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    #SerializeToString方法转换为二进制字符串
    return example_proto.SerializeToString()
  1. 定义一个函数,传入4个参数
  2. 使用前面定义的函数对4个参数分别转换成相应的格式
  3. 构建Example将转换完的数据创建一条数据
  4. 序列化 tf.Example:返回一个二进制的字符串
n_observations = int(1e4)
feature0 = np.random.choice([False, True], n_observations)
feature1 = np.random.randint(0, 5, n_observations)
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
feature3 = np.random.randn(n_observations)
  1. 定义一个一万备用
  2. 随机选择一万个布尔数据
  3. 随机选择一万个0、1、2、3、4这5个整数
  4. 随机构造字符串
  5. 随机构造浮点数
filename = 'tfrecord-1'

with tf.io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example)
  1. 定义文件名
  2. 定义一个写的模块,传进文件名,写入数据
  3. 迭代一万次
  4. 按照零到一万的索引,分别传入上面构造的4个特征
  5. 写入数据

这段代码执行后,会得到一个名为tfrecord-1的文件:
在这里插入图片描述

4、加载tfrecord文件

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

打印结果:

<TFRecordDatasetV2 shapes: (), types: tf.string>

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器学习杨卓越

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值