tf.data.TFRecordDataset 解析-带实例解析函数-边做边记

很多东西自己只是查了、用了,但却没有去总结吸收,还是需要一个个的去积累记录

目录

tf.data.TFRecordDataset   

 含义

 Dataset

函数解析

tf.data.Dataset.batch 、map、shuffle、repeat

 tf.data.Dataset.make_one_shot_iterator().get_next()

一个几乎通用tfrecord 的数据解析函数

tf.data.TFRecordDataset   

 含义

Class TFRecordDataset。A Dataset comprising records from one or more TFRecord files.   
Creates a TFRecordDataset to read for one or more TFRecord files.   
**即读tfrecord文件创建一个TFRecordDataset类的实例对象。其父类是Dataset。(也即是创建了一个tf.data.Dataset的对象)。  tf.data.Dataset有的方法,他会全都继承。**  

 Dataset

 **Dataset 是基类,表示一串元素(elements),其中每个元素包含了一或多个Tensor对象。例如:在一个图片pipeline中,一个元素可以是单个训练样本,它们带有一个表示图片数据的tensors和一个label组成的pair。包括了创造和变换(transform)datasets的方法,同时也允许从内存中的数据来初始化dataset。**  
TextLineDataset 从文本文件中读取行数据   
TFRecordDataset 从TFRecord文件中读取records   
FixLengthRecordDataset 从二进制文件中读取固定长度records   
Iterator 它提供了主要的方式来从一个dataset中抽取元素。通过Iterator.get_next() 返回的该操作会yields出Datasets中的下一个元素,作为输入pipeline和模型间的接口使用   
一个dataset由element组成,它们每个都具有相同的结构,一个element包含了一个或多个tf.Tensor对象,称为component,每个component都具有   
<br>  

函数解析

**TFRecordDataset** 解析tfrecord文件中的所有记录,可以直接使用dataset的map方法;也可使用直接使用这个类的方法repeat、shuffle、batch方法对dataset进行重复、混洗、分批;  
参数:(   
    filenames,   
    compression_type=None,
    buffer_size=None,
    num_parallel_reads=None   
)
一般只传第一个参数filenames即可   
filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.      
compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. 0 means no buffering.  
num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. Defaults to reading files sequentially.       

 

tf.data.Dataset.batch 、map、shuffle、repeat

batch ( batch_size,
drop_remainder=False
)
参数: batch_size :表示一次迭代的batch中去数据量的条数。 A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder:最后一个batch数据是否舍弃 (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case its has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
返回: Dataset: A Dataset.
batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:
dataset = dataset.batch(32)

shuffle
(
buffer_size,
seed=None,
reshuffle_each_iteration=None
)
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。Randomly shuffles the elements of this dataset.
参数:
buffer_size: A tf.int64 scalar tf.Tensor, representing the number of elements from this dataset from which the new dataset will sample.
seed: (Optional.) A tf.int64 scalar tf.Tensor, representing the random seed that will be used to create the distribution. See tf.set_random_seed for behavior.
reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to True.)
返回 : Dataset: A Dataset.

repeat
(count=None)
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch. Repeats this dataset count times . count: (Optional.) A tf.int64 scalar tf.Tensor, representing the number of times the dataset should be repeated. The default behavior (if count is None or -1) is for the dataset be repeated indefinitely.
返回 : Dataset: A Dataset.
假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:
dataset = dataset.repeat(5)
如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:  

 

 tf.data.Dataset.make_one_shot_iterator().get_next()

将dataset中的元素取出。Creates an Iterator for enumerating the elements of this dataset.
如何将这个dataset中的元素取出呢?方法是从Dataset中示例化一个Iterator,然后对Iterator进行迭代。 语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素当在非Eager模式下时。one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。   

 

一个几乎通用tfrecord 的数据解析函数

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Site    :
# @Software: PyCharm
# @License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA

# coding:utf-8

import os
import sys
import random
import tensorflow as tf
from collections import namedtuple
import json


def get_did():
        with open('./tower.did', 'r') as fr:
                line = fr.readline()
                _did = json.loads(line.strip())
        return _did


def parse_query_features(filename):
        ## tfrecord格式数据Features
        # _did = get_did()
        # print(_did)
        with tf.Session() as sess:
                dataset = tf.data.TFRecordDataset(tf.gfile.Glob(filename))
                # 下面batch里面的值控制输出多少条数据
                dataset = dataset.batch(2)
                batch_data = dataset.make_one_shot_iterator().get_next()
                data = sess.run(batch_data)
                i = 0
                for i in range(len(data)):
                        if i % 1000 == 0:
                                print(i)
                        example = tf.train.Example.FromString(data[i])
                        # str to example
                        features = example.features.feature
                        # print(features.keys())
                        # print ('!!!FEATURES: ',type(features))
                        # print("样本feature")
                        fea = {}
                        for ind, key in enumerate(features):
                                feature = features[key]
                                ftype = None
                                fvalue = None
                                if len(feature.bytes_list.value) > 0:
                                        ftype = 'bytes_list'
                                        fvalue = feature.bytes_list.value

                                if len(feature.float_list.value) > 0:
                                        ftype = 'float_list'
                                        fvalue = feature.float_list.value

                                if len(feature.int64_list.value) > 0:
                                        ftype = 'int64_list'
                                        fvalue = feature.int64_list.value
                                #    fea[key] = fvalue
                                print("{}\t{}\t{}\t{}".format(ind, key, ftype, fvalue))
                # if fea["1000011_s_u_did_debug"][0] in _did:
                #    print(fea)

## 统计某个record的part数据的大小。
def countTfRecord(file_path):
        count=0
        for record in tf.python_io.tf_record_iterator(file_path):
                count+=1
        print("数据{} 的样本条数为\t{}".format(file_path,count))

if __name__ == '__main__':
        # print(get_did())
        # 输入路径即可
        data_path = "./part-r-00000"

        #解析数据
        parse_query_features(data_path)        #统计某个样本的数据条数
        # countTfRecord(data_path)


参考与鸣谢:

何之源的Dataset API入门教程  : https://zhuanlan.zhihu.com/p/30751039   
Tensorflow tf.data.Dataset API学习笔记讲解,讲的很不错。https://zhuanlan.zhihu.com/p/37106443   
tensorflow入门:tfrecord 和tf.data.TFRecordDataset https://blog.csdn.net/yeqiustu/article/details/79793454     

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值