写在前面:
Python 3.6.5-debug
Apple LLVM version 10.0.1 (clang-1001.0.46.4)
Tensorflow 1.10.1-debug
1.Graph构图
Graph构图过程,顾名思义,主要讲session run函数调用之前,将图中的每个节点都构建入graph内执行步骤。
Dataset 分为两种,1. 真正获取数据的dataset;2. 对1中数据做改变的dataset;dataset的op定义都在core/kernels/data/文件夹下,1中的dataset统一存放在reader_dataset_ops.cc中例如TextLineDataset, FixedLengthRecordDataset, TFRecordDataset;2中的dataset在上述文件夹下,例如ShuffleDataset, RepeatDataset, MapDataset, TensorDataset等。
下图是一个关于dataset的小例子:
def ReadRecordsByIter() :
dataset = tf.data.TFRecordDataset('image.tfrecords')
#image_dataset = image_dataset.shuffle(1)
dataset = dataset.repeat(2)
iter = dataset.make_one_shot_iterator();
one = iter.get_next()
with tf.Session() as sess:
for i in range(2):
print(sess.run(one).decode())
例子很简单,1. 生成了一个TFRecordDataset;2. 调用该dataset生成了RepeatDataset;3. 调用dataset的make_one_shot_iter生成了iter对象;4.调用iter的get_next()的方法得到了一个one对象。现在剖析背后到底做了些什么 ? 首先TFRecordDataset继承于Dataset,构造函数显示,就是构造了一个TFRecordDataset对象;接下来看第二行。image_dataset是TFRecordDataset类型,调用了repeat方法,该方法在基类里面,函数如下:
def repeat(self, count=None):
return RepeatDataset(self, count)
该方法通过基类,生成了一个RepeatDataset的方法,该方法将self传入RepeatDatset的构造函数,如下:
class RepeatDataset(Dataset):
def __init__(self, input_dataset, count):
super(RepeatDataset, self).__init__()
self._input_dataset = input_dataset
第二行显示,他将调用它的dataset,做成了self._input_dataset,这样就记录下了repeat的input为TFRecordDataset。根据这个办法,就可以将input的dataset记录下来。
解析下一行,该行调用了make_one_shot_iterator,生成了一个OneShotIterator对象.
def make_one_shot_iterator(self):
@function.Defun(capture_by_value=True)
def _make_dataset():
return self._as_variant_tensor() # pylint: disable=protected-access
try:
_make_dataset.add_to_graph(ops.get_default_graph())
最后一句add_to_graph()是重点,他会调用到4行将调用self的as_variant_tensor()方法。将self注册到图中。self是最后一层dataset,在我们的例子中是RepeatDataset.
def _as_variant_tensor(self):
return gen_dataset_ops.repeat_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
count=self._count,
**flat_structure(self))
该函数又会调用input_dataset的as_variant_tensor来完成将所有的dataset注册到图中。
该逻辑都是在Python中完成,上述逻辑会在内存中生成GraphDef中,该图中不涉及到任何计算,只规定输入输出的shape,规定数据的流向。真正执行过程是在session.run()中执行的。接下的逻辑就是在session.run()中的。
最后一句的get_next()函数是Iterator类调用的,返回函数如下:
return sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self._output_types,
gen_dataset_ops.iterator_get_next(
self._iterator_resource,
output_types=nest.flatten(
sparse.as_dense_types(
self._output_types,
self._output_classes)),
output_shapes=nest.flatten(
sparse.as_dense_shapes(
self._output_shapes,
self._output_classes)),
name=name)), self._output_types,
self._output_shapes, self._output_classes)
2. Dataset流程
Dataset通常使用Iterator拿出数据,所以需要一个从DatasetOp到Dataset再到Iterator的构建过程才能等到Iterator,当我们拿到Iter的时候就可以通过Iterator来获取数据了。所以从Op到Iter通常需要三步,第一步生成Dataset,第二步生成Iterator,第三部循环获取Tensor,首先介绍一下Dataset的结构图,如下图:
2.1 dataset结构图
每个dataset,对应有一个datasetop类,每个datasetop类中有对应的dataset类,dataset类中又嵌套着Iterator类,如下图所示,简单举ShuffleDataasetOpBase、TextLenDatasetOp和TFRecordDatasetOp三个op例子。
其中op类中主要方法为MakeDataset方法,该方法主要创建Dataset类;
DatasetBase中有MakeIterator类,该方法主要创建Iterator
Iterator每个派生类实现了GetNextInternal方法来做不同的事情,因为Dataset分为root dataset和normal dataset,所以GetNextInternal的实现也大致分为两种,root的实现是从reader_中获取数据,normal的实现是从input iter递归调用GetNextInternal函数,得到数据后对数据进行缓存、变形、甄选等操作。
先进行第一步,从Op中得到Dataset。
2.2 MakeDataset
首先,根据注册在图中的op,生成dataset。生成下图op的方式是异步的,先将0 input的节点抛入ready队列,处理完该节点之后,将该节点作为input的节点的input pending size - 1,当后续节点input pending size减为0时,由最后一个input节点激活该节点,将节点放入ready队列中,处理,以此循环,直到最后一个节点处理完毕。
上图中所有的op(op_kernel)都会实现Compute或者ComputeAsync方法。在DatasetOpKernel中实现了该方法如下图:
void DatasetOpKernel::Compute(OpKernelContext* ctx) {
DatasetBase* dataset = nullptr;
MakeDataset(ctx, &dataset);
if (ctx->status().ok()) {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
}
}
void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
MakeDataset(ctx, input, output);
}
Datasetop大多数子类都共用同一个Compute方法就是该方法。
该方法会调用到252行的MakeDataset方法,因为能执行Compute的op都是input已经准备好了的。所以可以从ctx里get到input,不过input的类型都是vector<tensor>的,其实该Tensor利用其中真正存储类型为void*,Get方法为又被强转成T(模板参数)类型的方式,可以在Tensor中存储一个Dataset指针,通过GetDatasetFromVariant方法利用input的Tensor可以得到input的dataset,Get方法如下:
Status GetDatasetFromVariantTensor(const Tensor& tensor,
DatasetBase** out_dataset) {
const Variant& variant = tensor.scalar<Variant>()();
const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
*out_dataset = wrapper->get();
return Status::OK();
}
2.2.1 Normal Dataset
然后256行的MakeDataset是每个子op特有的dataset,根据op和input、ctx来构造处output的dataset,假设该Op是RepeatDatasetOp,代码如下图:
class RepeatDatasetOp : public UnaryDatasetOpKernel {
public:
explicit RepeatDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
*output = new Dataset(ctx, count, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
: GraphDatasetBase(ctx), count_(count), input_(input) {
input_->Ref();
}
MakeDataset会根据不同的Op返回不同的Dataset,并且在构造Dataset的时候会保留input的dataset,并将input的dataset的引用计数+1。这样链式调用下来我们就可以得到一个关于dataset的input调用链。
之后我们再返回上上图的248行,该行是将该op的dataset封装成output的Tensor,作为下一个op的input,具体实现如下,该方法是全局方法,使用variantWrapper包裹dataset后放入Tensor中,因为Tensor的buf_的派生类是Buffer<T>,实际存储是void*型的,所以道理上什么变量都可以放入其中,当然除了有Store方法外,当然还有Get方法。
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
return Status::OK();
}
2.2.2 Root Dataset
RootDatasetOp直接继承于DatasetOpKernel,并每个子类自己实现了两参的MakeDataset方法,我们以TFRecordDatasetOp为例。
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
const Tensor* filenames_tensor;
OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
OP_REQUIRES(
ctx, filenames_tensor->dims() <= 1,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
filenames.push_back(filenames_tensor->flat<string>()(i));
}
string compression_type;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
&compression_type));
int64 buffer_size = -1;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(ctx, buffer_size >= 0,
errors::InvalidArgument(
"`buffer_size` must be >= 0 (0 == no buffering)"));
*output =
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}
2.3 MakeIterator
上图中我们构建好了链式dataset,因为我们Python代码中写的是one = iter->get_next(),和session->run(one),这里需要注意一点,get_next()函数返回的是GetNextOp而不是get_next()返回的结果。
接下来我们继续看如何构造链式Iterator:
同样,当我们执行完底层的dataset的链式结构之后,会执行到MakeOneShotIterator,这里会得到一个OneShotIteratorOp,同理异步也会调用到这个类的Compute或者ComputeAsync方法。左边一列函数就是主要是在做MakeIterator,
该Op的主要目的就是创建一个包含链式iter的IteratorResource。
其中,先从ctx拿到resource manage,然后调用mgr的LookupOrCreate方法,该mgr主要保存了从名字到resource的映射,如果有重名的resource,将直接取出。该函数如下:
TF_RETURN_IF_ERROR(
ctx->resource_manager()->LookupOrCreate<IteratorResource>(
cinfo->container(), cinfo->name(), iterator,
[lib, this, &flib_def, &pflr](IteratorResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
nullptr, std::move(flib_def), std::move(pflr), lib);
return Status::OK();
}));
该函数接受一个二级指针Iterator,一个name,和一个如果没有找到的create的方法,根据这个lambda表达式创建一个resource,并且放入ResourceManage内;
然后根据return_value,这个值怎么来的前面已经讲过了,不再叙述。这块的详细代码如下:
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iter;
TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
2.4 MakeIterator
上图所示,我们拿到了最上层的dataset,并由其构建Iterator序列,假设最上层的dataset为RepeatDataset,MakeIterator是基类特有函数,每个派生类的实现了MakeIterator内的MakeIteratorInternal函数,具体代码如下:
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
return (*iterator)->Initialize(ctx);
}
该函数首先调用了MakeIteratorInternal生成了Iterator,然后再条用iterator的Init函数。MakeIteratorInternal和Init函数如下:
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
// make forever iterator .
} else if (count_ == 0) {
//make empty Iterator .
} else {
return std::unique_ptr<IteratorBase>(new FiniteIterator(
{this, strings::StrCat(prefix, "::FiniteRepeat")}));
}
}
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
: DatasetIterator<Dataset>(params), i_(0) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
代码浅显易懂,就不过多叙述了。如上图所示这样就完成了Iter的链式构造。iter之后由set_iterator函数放入IteratorResource中。并保存在ctx内,由将OneShotInteratorOp作为输入的的IteratorGetNextOp调用。
3 数据流向
其实这部分就是上节封装的IteratorResource中顶层dataset调用getnext方法,获取数据的过程。
ComputeAsync方法同理从ctx中ResourceManager中使用LookupResource接口得到上个op放入的IteratorResource;如下图。
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
//use thread pool
Status s =
iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
调用该IteratorResource的GetNext的方法,该方法会调用放入首层dataset的Iterator的GetNext函数,如下图:
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
}
该函数是基类实现的方法,并且每个派生类没有实现,其中GetNextInternal才是每个派生类实现的方法。
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
return s;
}
GetNextInternal方法如下,我们用RepeatDataset的Iterator举例,normal Iterator会调用它依赖的iter的GetNext()这样就形成了循环。
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
return Status::OK();
}
++i_;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
*end_of_sequence = true;
input_impl_.reset();
return Status::OK();
}
但是root Iterator不同,他是最底层的Iter,没有下层了,所以下层的GetNext函数就有另外的实现方式,我们以最常用的TFRecordDataset举例,它的Iterator的GetNext函数如下:
do {
// We are currently processing a file, so try to read the next line.
if (buffered_input_stream_) {
string line_contents;
Status s = buffered_input_stream_->ReadLine(&line_contents);
if (s.ok()) {
// Produce the line as output.
Tensor line_tensor(ctx->allocator({}), DT_STRING, {});
line_tensor.scalar<string>()() = line_contents;
out_tensors->emplace_back(std::move(line_tensor));
*end_of_sequence = false;
return Status::OK();
} else if (!errors::IsOutOfRange(s)) {
// Report non-EOF errors to the caller.
return s;
}
// We have reached the end of the current file, so maybe
// move on to next file.
ResetStreamsLocked();
++current_file_index_;
}
// Iteration ends when there are no more files to process.
if (current_file_index_ == dataset()->filenames_.size()) {
*end_of_sequence = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
} while (true);
它依赖了reader_去读取它的内容。
4. 常用的Dataset
4.1 RepeatDataset
有三种Iter,1. ForeverIter,2. EmptyIter, 3. FiniteIter,实现都非常简单,第一种就写while(true),第二种直接返回,第三种写当前批数小于设置批数。我们以第二种为例:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
return Status::OK();
}
++i_;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
*end_of_sequence = true;
input_impl_.reset();
return Status::OK();
}
4.2 ShuffleDataset
如图,该Iter自己有buffer,他先通过预先设定的buffersize填充buffer,如下图,需要注意的点是该input_impl是运行时定义的。
while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++;
LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
<< num_elements_ << " of "
<< this->dataset()->buffer_size_;
}
std::vector<Tensor> input_element;
bool end_of_input_sequence = false;
while (this->dataset()->count_ == -1 ||
epoch_ < this->dataset()->count_) {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
&end_of_input_sequence));
if (!end_of_input_sequence) {
first_call = false;
break;
}
一直会从input里GetNext直到buffer填满,然后随机一个到index,然后将其返回。
int64 offset =
Random() % (slices_.front()->end - slices_.front()->start);
int64 index =
(slices_.front()->start + offset) % this->dataset()->buffer_size_;
*out_tensors = std::move(buffer_[index]);
std::swap(
buffer_[index],
buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
slices_.front()->start++;
num_elements_--;
4.3 MapDataset
比较简单,就是将inputGetNext出来的数据再经过一次变换。如下图:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::vector<Tensor> args;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
if (*end_of_sequence) {
return Status::OK();
}
Status s =
dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
if (errors::IsOutOfRange(s)) {
*end_of_sequence = true;
return Status::OK();
} else {
return s;
}
}
4.4 PrefetchDataset
该Dataset,多创建了一个线程填充Buffer,另外一个线程去消费Buffer。GetNextInternal函数如下:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
cond_var_.wait(l);
}
if (cancelled_) {
return errors::Cancelled(
"PrefetchDatasetOp::Dataset::Iterator::GetNext");
}
if (!buffer_.empty()) {
return Consume(out_tensors, end_of_sequence);
}
另外一个线程在:
BufferElement buffer_element;
buffer_element.status = input_impl_->GetNext(
ctx, &buffer_element.value, &end_of_sequence);