1.序言
该模块(tf.python_io)是 TensorFlow 用来处理 TFRecords 文件的接口,其实现位于 tensorflow/python/lib/io/tf_record.py,主要包含四个部分:
class TFRecordCompressionType:记录的压缩类型。
class TFRecordOptions:用于操作TFRecord文件的选项。
class TFRecordWriter:将记录写入TFRecords文件的类。
tf_record_iterator(…):从TFRecords文件中读取记录的迭代器。
2.源码解析
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
# Defines the compression types of a TFRecords file: none, ZLIB, or GZIP.
class TFRecordCompressionType(object):
    """The type of compression for the record."""

    # Integer codes; TFRecordOptions.compression_type_map uses them as keys.
    NONE, ZLIB, GZIP = 0, 1, 2
# NOTE (translated from the original author's comment): this class is meant
# to be converted to a proto so that it lines up with the C++ interface.
class TFRecordOptions(object):
    """Options used for manipulating TFRecord files."""

    # Maps each TFRecordCompressionType code to the string name returned by
    # get_compression_type_string(); "no compression" is the empty string.
    compression_type_map = {
        TFRecordCompressionType.ZLIB: "ZLIB",
        TFRecordCompressionType.GZIP: "GZIP",
        TFRecordCompressionType.NONE: "",
    }

    def __init__(self, compression_type):
        """Store the compression type.

        Args:
            compression_type: One of the `TFRecordCompressionType` codes.
        """
        self.compression_type = compression_type

    @classmethod
    def get_compression_type_string(cls, options):
        """Return the string name of the compression type held by `options`.

        Args:
            options: A `TFRecordOptions` object, or a falsy value such as
                `None` to mean "no options".

        Returns:
            The entry of `compression_type_map` for
            `options.compression_type`, or "" when `options` is falsy.
        """
        if options:
            return cls.compression_type_map[options.compression_type]
        return ""
def tf_record_iterator(path, options=None):
    """An iterator that reads the records from a TFRecords file.

    Args:
        path: The path to the TFRecords file.
        options: (optional) A `TFRecordOptions` object; it supplies the
            compression type used when opening the file.

    Yields:
        Strings (the value of `reader.record()` after each successful
        `GetNext` call).

    Raises:
        IOError: If no reader could be created for `path`.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)
    with errors.raise_exception_on_not_ok_status() as status:
        # Create the file reader object through the pywrap_tensorflow wrapper.
        # NOTE(review): the literal 0 is presumably the start offset -- confirm
        # against the PyRecordReader_New signature.
        reader = pywrap_tensorflow.PyRecordReader_New(
            compat.as_bytes(path), 0, compat.as_bytes(compression_type),
            status)

    if reader is None:
        raise IOError("Could not open %s." % path)

    # The try/finally guarantees the reader is closed even when the consumer
    # abandons the generator early (GeneratorExit is raised at the `yield`)
    # or when reading fails with an error other than OutOfRangeError.
    try:
        while True:
            try:
                with errors.raise_exception_on_not_ok_status() as status:
                    reader.GetNext(status)
            except errors.OutOfRangeError:
                # End of file reached: stop iterating.
                break
            yield reader.record()  # Hand out the file's records one by one.
    finally:
        reader.Close()
class TFRecordWriter(object):
    """A class to write records to a TFRecords file.

    Because it implements `__enter__` and `__exit__`, it can be used in a
    `with` statement; leaving the block closes the writer.
    """

    # TODO(josh11b): Support appending?
    def __init__(self, path, options=None):
        """Open the file at `path` and create the underlying writer object.

        Args:
            path: The path to the TFRecords file.
            options: (optional) A `TFRecordOptions` object.

        Raises:
            IOError: If `path` cannot be opened for writing.
        """
        # Resolve the compression type name from the options.
        codec_name = TFRecordOptions.get_compression_type_string(options)
        with errors.raise_exception_on_not_ok_status() as open_status:
            # Create the wrapped writer that all other methods delegate to.
            self._writer = pywrap_tensorflow.PyRecordWriter_New(
                compat.as_bytes(path),
                compat.as_bytes(codec_name),
                open_status)

    def __enter__(self):
        """Enter a `with` block and return this writer."""
        return self

    def __exit__(self, unused_type, unused_value, unused_traceback):
        """Exit a `with` block, closing the file."""
        self.close()

    def write(self, record):
        """Write one record to the file.

        Args:
            record: str
        """
        # The actual write is carried out by the wrapped writer object.
        self._writer.WriteRecord(record)

    def flush(self):
        """Flush the buffered content to the file on disk."""
        with errors.raise_exception_on_not_ok_status() as flush_status:
            self._writer.Flush(flush_status)

    def close(self):
        """Close the file."""
        with errors.raise_exception_on_not_ok_status() as close_status:
            self._writer.Close(close_status)