tf.train.Example的用法

前言

最近在看到一个代码时,里面用到了tf.train.Example,于是学习了其用法,这里记录一下,也希望能对其他朋友有用。
另外,本文涉及的代码基于python 3.6.5 tensorflow 1.8.0
tf.train.Example主要用在将数据处理成二进制方面,一般是为了提升IO效率和方便管理数据。下面按调用顺序介绍使用tf.train.Example涉及的几个类。

tf.train.BytesList等

现在我们有一个data.txt文件,内容如下:

21
This is a test data file.
We will convert this text file to bin file.

文件中第一行是个整数,第二行和第三行都是字符串。这是我们处理的原始数据。
我们先使用下面的代码将数据读进来:

import struct
import tensorflow as tf


def read_text_file(text_file):
    lines = []
    with open(text_file, "r") as f:
        for line in f:
            lines.append(line.strip())
    return lines

def text_to_binary(in_file, out_file):
    inputs = read_text_file(in_file)

    with open(out_file, 'wb') as writer:
    	pass

if __name__ == '__main__':
    text_to_binary('data.txt', 'data.bin')

格式化原始数据可以使用tf.train.BytesList tf.train.Int64List tf.train.FloatList三个类。这三个类都有实例属性value用于我们将值传进去,一般tf.train.Int64List tf.train.FloatList对应处理整数和浮点数,tf.train.BytesList用于处理其他类型的数据。
这里第一行数据我们可以用tf.train.Int64List处理,而第二、第三行数据我们使用tf.train.BytesList处理。下面我们看代码实现,我们将上述代码的pass替换如下:

        data_id = tf.train.Int64List(value=[int(inputs[0])])
        data = tf.train.BytesList(value=[bytes(' '.join(inputs[1:]), encoding='utf-8')])

注意到,tf.train.Int64List的value值需要是int数据的列表,而tf.train.BytesList的value值需要是bytes数据的列表。
我们分别打印data_id和data的值可以看到:

value: 21

value: "This is a test data file. We will convert this text file to bin file."

这样我们就完成了第一步操作。

tf.train.Feature

tf.train.Feature有三个属性为tf.train.bytes_list tf.train.float_list tf.train.int64_list,显然我们只需要根据上一步得到的值来设置tf.train.Feature的属性就可以了,如下所示:

tf.train.Feature(int64_list=data_id)
tf.train.Feature(bytes_list=data)

tf.train.Features

从名字来看,我们应该能猜出tf.train.Featurestf.train.Feature的复数,事实上tf.train.Features有属性为feature,这个属性的一般设置方法是传入一个字典,字典的key是字符串(feature名),而值是tf.train.Feature对象。因此,我们可以这样得到tf.train.Features对象:

        feature_dict = {
            "data_id": tf.train.Feature(int64_list=data_id),
            "data": tf.train.Feature(bytes_list=data)
        }
        features = tf.train.Features(feature=feature_dict)

tf.train.Example

终于到我们的主角了。tf.train.Example有一个属性为features,我们只需要将上一步得到的结果再次当做参数传进来即可。
另外,tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
当然,既然有对象序列化为字符串的方法,那么肯定有从字符串反序列化到对象的方法,该方法是FromString(),需要传递一个tf.train.Example对象序列化后的字符串进去做为参数才能得到反序列化的对象。
在我们这里,只需要构建tf.train.Example对象并序列化就可以了,这一步的代码为:

        example = tf.train.Example(features=features)
        example_str = example.SerializeToString()

好了,那么现在我们看一下将data.txt处理成data.bin的完整代码:

import struct
import tensorflow as tf


def read_text_file(text_file):
    lines = []
    with open(text_file, "r") as f:
        for line in f:
            lines.append(line.strip())
    return lines


def text_to_binary(in_file, out_file):
    inputs = read_text_file(in_file)

    with open(out_file, 'wb') as writer:
        data_id = tf.train.Int64List(value=[int(inputs[0])])
        data = tf.train.BytesList(value=[bytes(' '.join(inputs[1:]), encoding='utf-8')])

        feature_dict = {
            "data_id": tf.train.Feature(int64_list=data_id),
            "data": tf.train.Feature(bytes_list=data)
        }
        features = tf.train.Features(feature=feature_dict)

        example = tf.train.Example(features=features)
        example_str = example.SerializeToString()

        str_len = len(example_str)

        writer.write(struct.pack('H', str_len))
        writer.write(struct.pack('%ds' % str_len, example_str))


if __name__ == '__main__':
    text_to_binary('data.txt', 'data.bin')

代码里还涉及到了struct模块,关于struct模块的用法可以参考我的这篇文章:Python二进制数据处理

  • 81
    点赞
  • 133
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值