引言
tensorflow-IO提供了丰富的对接文件系统的接口,本文从源代码出发分析tensorflow-IO的文件系统的相关接口,学习如何为这些以C编写的文件系统提供Python接口。
从使用出发
以最简单的MINIST手写数据集训练出发,分析函数的调用过程。首先使用tensorflow_io需要pip install对应的库。之后就可以在使用tensorflow_io加载对应的数据集。
import tensorflow_io as tfio
dataset_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
d_train = tfio.IODataset.from_mnist(
dataset_url + "train-images-idx3-ubyte.gz",
dataset_url + "train-labels-idx1-ubyte.gz",
)
(xs, ys), _ ,path= datasets.mnist.load_data() # 自动下载
我们知道tensorflow中的dataset是已经处理好的tensor,本文的重点不在数据预处理而是重点关注数据从磁盘如何加载的,方便之后在这方面做出改进。持续跟踪代码
@classmethod
def from_mnist(cls, images=None, labels=None, **kwargs):
with tf.name_scope(kwargs.get("name", "IOFromMNIST")):
return mnist_dataset_ops.MNISTIODataset(
images, labels, internal=True, **kwargs
)
def MNISTIODataset(images=None, labels=None, internal=True):
"""MNISTIODataset"""
assert internal, (
"MNISTIODataset constructor is private; please use one "
"of the factory methods instead (e.g., "
"IODataset.from_mnist())"
)
assert (
images is not None or labels is not None
), "images and labels could not be all None"
images_dataset = MNISTImageIODataset(images) if images is not None else None
labels_dataset = MNISTLabelIODataset(labels) if labels is not None else None
if images is None:
return labels_dataset
if labels is None:
return images_dataset
return tf.data.Dataset.zip((images_dataset, labels_dataset))
class MNISTImageIODataset(tf.data.Dataset):
def __init__(self, filename):
_, compression = core_ops.io_file_info(filename)
rows = tf.io.decode_raw(
core_ops.io_file_read(filename, 8, 4, compression=compression),
tf.int32,
little_endian=False,
)
cols = tf.io.decode_raw(
core_ops.io_file_read(filename, 12, 4, compression=compression),
tf.int32,
little_endian=False,
)
会发现最终文件路径被传入core_ops.io_file_info()函数中,而这个io_file_info()函数是没看不到源代码的。只能先看一下core_ops的定义,如下。
core_ops = LazyLoader("core_ops", "libtensorflow_io.so")
这里能够看到core_ops是一个python中的module模块,而这个模块的功能是从libtensorflow_io.so这个二文件(其实为动态链接库)中加载的。LazyLoader函数最终将调用_load_library()函数,传入文件的名字。注意这个"core_ops"并没有实际的作用,只是module对象的名字。
class LazyLoader(types.ModuleType):
def __init__(self, name, library):
self._mod = None
self._module_name = name
self._library = library
super().__init__(self._module_name)
def _load(self):
if self._mod is None:
self._mod = _load_library(self._library)
return self._mod
def __getattr__(self, attrb):
return getattr(self._load(), attrb)
def __dir__(self):
return dir(self._load())
在_load_library中会根据功能的具体分类来调用不同的方法加载。
def _load_library(filename, lib="op"):
"""_load_library"""
f = inspect.getfile(sys._getframe(1)) # pylint: disable=protected-access
# Construct filename
f = os.path.join(os.path.dirname(f), filename)
filenames = [f]
# Add datapath to load if en var is set, used for running tests where shared
# libraries are built in a different path
datapath = os.environ.get("TFIO_DATAPATH")
if datapath is not None:
# Build filename from:
# `datapath` + `tensorflow_io` + `package_name` + `relpath_to_library`
rootpath = os.path.dirname(sys.modules["tensorflow_io"].__file__)
filename = sys.modules[__name__].__file__
f = os.path.join(
datapath,
"tensorflow_io",
os.path.relpath(os.path.dirname(filename), rootpath),
os.path.relpath(f, os.path.dirname(filename)),
)
filenames.append(f)
# Function to load the library, return True if file system library is loaded
if lib == "op":
load_fn = tf.load_op_library
elif lib == "dependency":
load_fn = lambda f: ctypes.CDLL(f, mode=ctypes.RTLD_GLOBAL)
elif lib == "fs":
load_fn = lambda f: tf.experimental.register_filesystem_plugin(f) is None
else:
load_fn = lambda f: tf.compat.v1.load_file_system_library(f) is None
# Try to load all paths for file, fail if none succeed
errs = []
for f in filenames:
try:
l = load_fn(f)
if l is not None:
return l
except (tf.errors.NotFoundError, OSError) as e:
errs.append(str(e))
raise NotImplementedError(
"unable to open file: "
+ "{}, from paths: {}\ncaused by: {}".format(filename, filenames, errs)
)
可以很明显的看到,当lib为’op’时调用tf.load_op_library,当lib为’op’时调用。而当继续向下时,load_op_library里面的函数以TF开头的都看不到源码了。如下:
def TF_LoadLibrary(arg0): # real signature unknown; restored from __doc__
""" TF_LoadLibrary(arg0: str) -> tensorflow.python.client._pywrap_tf_session.TF_Library """
pass
说明这个的具体实现已经又C代码编译成动态链接库了,这里只是编译器为了给用户看一下大致的接口定义。也就是说甚至这个接口的定义都有可能是错误的。
从tensorflow_io源码入手
从动态链接库文件入手
既然我们知道core_ops的来源是libtensorflow_io.so的动态链接库,那么我们从构建这个动态链接库的过程入手去看源代码。
首先tensorflow使用的是bazel项目构建工具,bazel将每个有Build文件的目录划分为包。我们关注tensorflow_io的核心代码包部分。在tensorflow_io的核心代码目录/tensorflow_io/core/下我们找到libtensorflw_io的BUILD.bzl文件。打开这个文件我们能看到libtensorflow_io.so的构建设置
cc_binary(
name = "python/ops/libtensorflow_io.so",
copts = tf_io_copts(),
linkshared = 1,
deps = [
"//tensorflow_io/core:arrow_ops",
"//tensorflow_io/core:bigquery_ops",
"//tensorflow_io/core:audio_video_ops",
"//tensorflow_io/core:avro_ops",
"//tensorflow_io/core:orc_ops",
"//tensorflow_io/core:cpuinfo",
"//tensorflow_io/core:file_ops",
"//tensorflow_io/core:filesystem_ops",
"//tensorflow_io/core:grpc_ops",
"//tensorflow_io/core:hdf5_ops",
"//tensorflow_io/core:image_ops",
"//tensorflow_io/core:json_ops",
"//tensorflow_io/core:kafka_ops",
"//tensorflow_io/core:kinesis_ops",
"//tensorflow_io/core:lmdb_ops",
"//tensorflow_io/core:numpy_ops",
"//tensorflow_io/core:parquet_ops",
"//tensorflow_io/core:pcap_ops",
"//tensorflow_io/core:pulsar_ops",
"//tensorflow_io/core:obj_ops",
"//tensorflow_io/core:operation_ops",
"//tensorflow_io/core:pubsub_ops",
"//tensorflow_io/core:serialization_ops",
"//tensorflow_io/core:sql_ops",
"//tensorflow_io/core:text_ops",
"//tensorflow_io/core:ignite_ops",
"//tensorflow_io/core:mongodb_ops",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + select({
"@bazel_tools//src/conditions:windows": [],
"//conditions:default": [
"//tensorflow_io/core:core_ops",
"//tensorflow_io/core:elasticsearch_ops",
"//tensorflow_io/core:genome_ops",
"//tensorflow_io/core:optimization",
"//tensorflow_io/core:oss_ops",
"//tensorflow_io/core/kernels/gsmemcachedfs:gs_memcached_file_system",
],
}) + select({
"//tensorflow_io/core:static_build_on": [
"//tensorflow_io/core/filesystems:filesystem_plugins",
],
"//conditions:default": [],
}),
)
cc_binary中的deps选项是构建动态链接库所需要的其他库,由于函数core_ops.io_file_info()是对文件的访问操作,所以我们具体看一下file_ops,file_ops模块如下:
cc_library(
name = "file_ops",
srcs = [
"kernels/file_kernels.cc",
"ops/file_ops.cc",
],
copts = tf_io_copts(),
linkstatic = True,
deps = [
"//tensorflow_io/core:dataset_ops",
],
alwayslink = 1,
)
最终我们定位到kernels/file_kernels.cc文件和ops/files_ops.cc文件。
首先来看ops/file_ops.cc文件
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace io {
namespace {
REGISTER_OP("IO>FileInfo")
.Input("input: string")
.Output("length: int64")
.Output("compression: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
c->set_output(1, c->Scalar());
return Status::OK();
});
REGISTER_OP("IO>FileRead")
.Input("input: string")
.Input("offset: int64")
.Input("length: int64")
.Input("compression: string")
.Output("value: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
return Status::OK();
});
REGISTER_OP("IO>FileInit")
.SetIsStateful()
.Input("input: string")
.Output("resource: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("IO>FileCall")
.SetIsStateful()
.Input("input: string")
.Input("final: bool")
.Input("resource: resource")
.Output("output: string")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("IO>FileSync")
.Input("resource: resource")
.SetShapeFn(shape_inference::ScalarShape);
} // namespace
} // namespace io
}
很明显的看到,这里调用tensorflow中op.h定义的operation相关接口将文件的操作注册成tensorflow的计算图节点(在tensorflow中被称为kerenal)。
Operation接口的实现
在将Operation接口注册过后需要为运算提供实现。具体是实现tensorflow中的OpKernel接口
class OpKernel {
public:
OpKernel(OpKernelConstruction* context, bool is_deferred);
OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
bool is_deferred);
virtual ~OpKernel();
virtual void Compute(OpKernelContext* context) = 0;
// Returns nullptr iff this op kernel is synchronous.
virtual AsyncOpKernel* AsAsync() { return nullptr; }
virtual bool IsExpensive() { return expensive_; }
// Returns a pointer to the tensor stored inside constant ops.
virtual const Tensor* const_tensor() const { return nullptr; }
// Accessors.
const NodeDef& def() const { return props_->node_def; }
const string& name() const { return props_->node_def.name(); }
absl::string_view name_view() const { return name_view_; }
const string& type_string() const { return props_->node_def.op(); }
absl::string_view type_string_view() const { return type_string_view_; }
const string& requested_input(int i) const {
return props_->node_def.input(i);
}
const string& requested_device() const { return props_->node_def.device(); }
int num_inputs() const { return props_->input_types.size(); }
DataType input_type(int i) const { return props_->input_types[i]; }
const DataTypeVector& input_types() const { return props_->input_types; }
const MemoryTypeVector& input_memory_types() const {
return input_memory_types_;
}
int num_outputs() const { return props_->output_types.size(); }
DataType output_type(int o) const { return props_->output_types[o]; }
const DataTypeVector& output_types() const { return props_->output_types; }
const MemoryTypeVector& output_memory_types() const {
return output_memory_types_;
}
Status InputRange(StringPiece input_name, int* start, int* stop) const;
Status OutputRange(StringPiece output_name, int* start, int* stop) const;
// Returns `true` if and only if this kernel uses deferred execution.
bool is_deferred() const { return is_deferred_; }
virtual string TraceString(OpKernelContext* ctx, bool verbose);
protected:
string GetTraceArgument(OpKernelContext* ctx);
private:
const std::shared_ptr<const NodeProperties> props_;
const MemoryTypeVector input_memory_types_;
const MemoryTypeVector output_memory_types_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
const absl::string_view name_view_;
const absl::string_view type_string_view_;
const int graph_def_version_;
const bool is_deferred_;
bool expensive_;
TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};
Operator类中定义了一些操作核的属性,如节点定义、名称、类型、输入和输出等。参考tensorflow的官方文档,我们只需要重写Compute方法就能重新定义节点的计算内容。详细介绍可以看下这篇博客,这里不做更多阐述。回到tensorflow_io中的file_kernel.cc以FileInfoOp举例:
class FileInfoOp : public OpKernel {
public:
explicit FileInfoOp(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
string input = input_tensor->scalar<tstring>()();
uint64 size;
OP_REQUIRES_OK(context, env_->GetFileSize(input, &size));
std::unique_ptr<tensorflow::RandomAccessFile> file;
OP_REQUIRES_OK(context, env_->NewRandomAccessFile(input, &file));
Tensor* length_tensor = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape({}), &length_tensor));
length_tensor->scalar<int64>()() = size;
Tensor* compression_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
&compression_tensor));
StringPiece result;
char buffer[10] = {0};
Status status = file->Read(0, 10, &result, buffer);
if (!status.ok() || result.size() != 10) {
return;
}
// deflation- third byte must be 0x08.
if (memcmp(buffer, "\x1F\x8B\x08", 3) != 0) {
return;
}
// No reserved flags set.
if ((buffer[3] & 0xE0) != 0) {
return;
}
compression_tensor->scalar<tstring>()() = "GZIP";
}
private:
mutable mutex mu_;
Env* env_ TF_GUARDED_BY(mu_);
};
源代码中的OP_REQUIRES_OK宏定义类似assert方法,主要用于判断context相关状态是否OK,具体以OP_REQUIRES_OK(context, context->input(“input”, &input_tensor))举例。它传入一个只想OpKernelContext对象的指针context,随后执行context指向对象中的input方法,并且根据input()方法返回的状态值进行不同的行为,如果不为OK则会在运行时抛出一个错误。具体的宏定义在/tensorflow/core/framework/op_requires.h中。
通过这个函数再和最先开始的core_ops.io_file_info()函数做对比,能够自然的推出其实就是调用的这个C的函数,不过Python中不使用驼峰命名方式而是蛇形命名法。
调用文件系统的接口
FileInfoOp类中对文件的操作是使用的tensorflow中的RandomAccessFile,以上面的FileInfoOp举例:
std::unique_ptr<tensorflow::RandomAccessFile> file;
OP_REQUIRES_OK(context, env_->NewRandomAccessFile(input, &file));
StringPiece result;
char buffer[10] = {0};
Status status = file->Read(0, 10, &result, buffer);
RandomAccessFile是一个抽象类,不同文件系统中对文件的访问都需要实现这个抽象类。而不同的文件系统又需要实现FileSystem抽象类,总的来说就是文件系统和文件的访问方式都需要实现。
Env* env_是用来配置文件系统环境变脸的类,NewRandomAccessFile()函数是用来创建新的RandomAccessFile对象的方法。以下是env.cc中NewRandomAccessFile的实现
Status Env::NewRandomAccessFile(const string& fname,
std::unique_ptr<RandomAccessFile>* result) {
FileSystem* fs;
TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs));
return fs->NewRandomAccessFile(fname, result);
}
可以看到env中本质也是使用FileSystem类的NewRandomAccessFile()函数,FileSystem在tensorflow中的file_sys.h中声明,同样RandomAccessFile也在tensorflow中的file_sys.h中声明。
最后以S3文件系统举例,文件系统都需要继承FileSystem类。这样就可以实例化到底是使用的哪个文件系统进行上面的一系列流程。
class S3FileSystem : public FileSystem {
public:
S3FileSystem();
~S3FileSystem();
Status NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) override;
Status NewRandomAccessFile(const string& fname,
std::unique_ptr<RandomAccessFile>* result,
bool use_multi_part_download);
Status NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
文章的开头举得例子是使用http文件系统来加载MINIST数据集,如果要使用DAOS则把url改成如下面代码所示:
dfs_url = "daos://TEST_POOL/TEST_CONT/" # This the path you'll be using to load and access the dataset
pwd = !pwd
posix_url = pwd[0] + "/tests/test_dfs/"
images = dfs_url + "train.gz"
labels = dfs_url + "train_labels.gz"
d_train = tfio.IODataset.from_mnist(
images,
labels
)
那么又是如何通过URL来区分不同文件系统的文件呢,我准备留到下一篇文章来讲解。
总结
以上就是使用python调用动态链接库的全过程,通过对源码的阅读能够大致了解tensorflow中文件系统的设计模式和原理。需要注意的是在python中读取libtensorflow_io.so时将得到的Module的name参数赋值为了core_ops,而libtensorflow_io.so中也有一个core_ops模块,这两个是没有什么关系的。