tensorflow dataset模块

写在前面:

Python 3.6.5-debug

Apple LLVM version 10.0.1 (clang-1001.0.46.4)

Tensorflow 1.10.1-debug

1.Graph构图

Graph构图过程,顾名思义,主要讲session run函数调用之前,将图中的每个节点都构建入graph内执行步骤。

Dataset 分为两种,1. 真正获取数据的dataset;2. 对1中数据做改变的dataset;dataset的op定义都在core/kernels/data/文件夹下,1中的dataset统一存放在reader_dataset_ops.cc中例如TextLineDataset, FixedLengthRecordDataset, TFRecordDataset;2中的dataset在上述文件夹下,例如ShuffleDataset, RepeatDataset, MapDataset, TensorDataset等。

下图是一个关于dataset的小例子:

def ReadRecordsByIter() :

    dataset = tf.data.TFRecordDataset('image.tfrecords')
    #image_dataset = image_dataset.shuffle(1)
    dataset = dataset.repeat(2)
    iter = dataset.make_one_shot_iterator();
    one = iter.get_next()

    with tf.Session() as sess:
        for i in range(2):
            print(sess.run(one).decode())

例子很简单,1. 生成了一个TFRecordDataset;2. 调用该dataset生成了RepeatDataset;3. 调用dataset的make_one_shot_iter生成了iter对象;4.调用iter的get_next()的方法得到了一个one对象。现在剖析背后到底做了些什么 ? 首先TFRecordDataset继承于Dataset,构造函数显示,就是构造了一个TFRecordDataset对象;接下来看第二行。image_dataset是TFRecordDataset类型,调用了repeat方法,该方法在基类里面,函数如下:

  def repeat(self, count=None):
    return RepeatDataset(self, count)

 该方法通过基类,生成了一个RepeatDataset的方法,该方法将self传入RepeatDatset的构造函数,如下:

class RepeatDataset(Dataset):
  def __init__(self, input_dataset, count):
    super(RepeatDataset, self).__init__()
    self._input_dataset = input_dataset

第二行显示,他将调用它的dataset,做成了self._input_dataset,这样就记录下了repeat的input为TFRecordDataset。根据这个办法,就可以将input的dataset记录下来。

解析下一行,该行调用了make_one_shot_iterator,生成了一个OneShotIterator对象.

  def make_one_shot_iterator(self):
    @function.Defun(capture_by_value=True)
    def _make_dataset():
      return self._as_variant_tensor()  # pylint: disable=protected-access
    try:
      _make_dataset.add_to_graph(ops.get_default_graph())

最后一句add_to_graph()是重点,他会调用到4行将调用self的as_variant_tensor()方法。将self注册到图中。self是最后一层dataset,在我们的例子中是RepeatDataset.

  def _as_variant_tensor(self):
    return gen_dataset_ops.repeat_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        count=self._count,
        **flat_structure(self))

该函数又会调用input_dataset的as_variant_tensor来完成将所有的dataset注册到图中。

该逻辑都是在Python中完成,上述逻辑会在内存中生成GraphDef中,该图中不涉及到任何计算,只规定输入输出的shape,规定数据的流向。真正执行过程是在session.run()中执行的。接下的逻辑就是在session.run()中的。

最后一句的get_next()函数是Iterator类调用的,返回函数如下:

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.as_dense_types(
                                          self._output_types,
                                          self._output_classes)),
                                  output_shapes=nest.flatten(
                                      sparse.as_dense_shapes(
                                          self._output_shapes,
                                          self._output_classes)),
                                  name=name)), self._output_types,
        self._output_shapes, self._output_classes)

2. Dataset流程 

Dataset通常使用Iterator拿出数据,所以需要一个从DatasetOp到Dataset再到Iterator的构建过程才能等到Iterator,当我们拿到Iter的时候就可以通过Iterator来获取数据了。所以从Op到Iter通常需要三步,第一步生成Dataset,第二步生成Iterator,第三部循环获取Tensor,首先介绍一下Dataset的结构图,如下图:

2.1 dataset结构图

 

每个dataset,对应有一个datasetop类,每个datasetop类中有对应的dataset类,dataset类中又嵌套着Iterator类,如下图所示,简单举ShuffleDataasetOpBase、TextLenDatasetOp和TFRecordDatasetOp三个op例子。

 

其中op类中主要方法为MakeDataset方法,该方法主要创建Dataset类;

DatasetBase中有MakeIterator类,该方法主要创建Iterator

 

Iterator每个派生类实现了GetNextInternal方法来做不同的事情,因为Dataset分为root dataset和normal dataset,所以GetNextInternal的实现也大致分为两种,root的实现是从reader_中获取数据,normal的实现是从input iter递归调用GetNextInternal函数,得到数据后对数据进行缓存、变形、甄选等操作。

先进行第一步,从Op中得到Dataset。

 

2.2 MakeDataset

首先,根据注册在图中的op,生成dataset。生成下图op的方式是异步的,先将0 input的节点抛入ready队列,处理完该节点之后,将该节点作为input的节点的input pending size - 1,当后续节点input pending size减为0时,由最后一个input节点激活该节点,将节点放入ready队列中,处理,以此循环,直到最后一个节点处理完毕。

上图中所有的op(op_kernel)都会实现Compute或者ComputeAsync方法。在DatasetOpKernel中实现了该方法如下图: 

  void DatasetOpKernel::Compute(OpKernelContext* ctx) {
    DatasetBase* dataset = nullptr;
    MakeDataset(ctx, &dataset);
    if (ctx->status().ok()) {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
      OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
    }
  }
 void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                         DatasetBase** output) {
    DatasetBase* input;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
    MakeDataset(ctx, input, output);
  }

Datasetop大多数子类都共用同一个Compute方法就是该方法。

该方法会调用到252行的MakeDataset方法,因为能执行Compute的op都是input已经准备好了的。所以可以从ctx里get到input,不过input的类型都是vector<tensor>的,其实该Tensor利用其中真正存储类型为void*,Get方法为又被强转成T(模板参数)类型的方式,可以在Tensor中存储一个Dataset指针,通过GetDatasetFromVariant方法利用input的Tensor可以得到input的dataset,Get方法如下:

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                     DatasetBase** out_dataset) {
    const Variant& variant = tensor.scalar<Variant>()();
    const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
    *out_dataset = wrapper->get();
    
    return Status::OK();
  }

2.2.1 Normal Dataset

然后256行的MakeDataset是每个子op特有的dataset,根据op和input、ctx来构造处output的dataset,假设该Op是RepeatDatasetOp,代码如下图:

class RepeatDatasetOp : public UnaryDatasetOpKernel {
   public:
    explicit RepeatDatasetOp(OpKernelConstruction* ctx)
        : UnaryDatasetOpKernel(ctx) {}

   protected:
    void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                     DatasetBase** output) override {
      int64 count;
      OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
      *output = new Dataset(ctx, count, input);
    }

   private:
    class Dataset : public GraphDatasetBase {
     public:
      Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
          : GraphDatasetBase(ctx), count_(count), input_(input) {
        input_->Ref();
      }

MakeDataset会根据不同的Op返回不同的Dataset,并且在构造Dataset的时候会保留input的dataset,并将input的dataset的引用计数+1。这样链式调用下来我们就可以得到一个关于dataset的input调用链。

之后我们再返回上上图的248行,该行是将该op的dataset封装成output的Tensor,作为下一个op的input,具体实现如下,该方法是全局方法,使用variantWrapper包裹dataset后放入Tensor中,因为Tensor的buf_的派生类是Buffer<T>,实际存储是void*型的,所以道理上什么变量都可以放入其中,当然除了有Store方法外,当然还有Get方法。

  Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
    tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
    return Status::OK();
  }

2.2.2 Root Dataset

RootDatasetOp直接继承于DatasetOpKernel,并每个子类自己实现了两参的MakeDataset方法,我们以TFRecordDatasetOp为例。

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
      const Tensor* filenames_tensor;
      OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
      OP_REQUIRES(
          ctx, filenames_tensor->dims() <= 1,
          errors::InvalidArgument("`filenames` must be a scalar or a vector."));

      std::vector<string> filenames;
      filenames.reserve(filenames_tensor->NumElements());
      for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
        filenames.push_back(filenames_tensor->flat<string>()(i));
      }

      string compression_type;
      OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
                                                      &compression_type));

      int64 buffer_size = -1;
      OP_REQUIRES_OK(
          ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
      OP_REQUIRES(ctx, buffer_size >= 0,
                  errors::InvalidArgument(
                      "`buffer_size` must be >= 0 (0 == no buffering)"));

      *output =
          new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
    }

2.3 MakeIterator

上图中我们构建好了链式dataset,因为我们Python代码中写的是one = iter->get_next(),和session->run(one),这里需要注意一点,get_next()函数返回的是GetNextOp而不是get_next()返回的结果。

接下来我们继续看如何构造链式Iterator:

同样,当我们执行完底层的dataset的链式结构之后,会执行到MakeOneShotIterator,这里会得到一个OneShotIteratorOp,同理异步也会调用到这个类的Compute或者ComputeAsync方法。左边一列函数就是主要是在做MakeIterator,

该Op的主要目的就是创建一个包含链式iter的IteratorResource。

其中,先从ctx拿到resource manage,然后调用mgr的LookupOrCreate方法,该mgr主要保存了从名字到resource的映射,如果有重名的resource,将直接取出。该函数如下:

 TF_RETURN_IF_ERROR(
          ctx->resource_manager()->LookupOrCreate<IteratorResource>(
              cinfo->container(), cinfo->name(), iterator,
              [lib, this, &flib_def, &pflr](IteratorResource** ret)
                  EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                    *ret = new IteratorResource(
                        output_dtypes_, output_shapes_, graph_def_version_,
                        nullptr, std::move(flib_def), std::move(pflr), lib);
                    return Status::OK();
                  }));

该函数接受一个二级指针Iterator,一个name,和一个如果没有找到的create的方法,根据这个lambda表达式创建一个resource,并且放入ResourceManage内;

然后根据return_value,这个值怎么来的前面已经讲过了,不再叙述。这块的详细代码如下:

DatasetBase* dataset;
      TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
      IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
      std::unique_ptr<IteratorBase> iter;
      TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
      TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));

2.4 MakeIterator

上图所示,我们拿到了最上层的dataset,并由其构建Iterator序列,假设最上层的dataset为RepeatDataset,MakeIterator是基类特有函数,每个派生类的实现了MakeIterator内的MakeIteratorInternal函数,具体代码如下:

Status MakeIterator(IteratorContext* ctx, const string& prefix,
                        std::unique_ptr<IteratorBase>* iterator) const {
      *iterator = MakeIteratorInternal(prefix);
      return (*iterator)->Initialize(ctx);
    }

该函数首先调用了MakeIteratorInternal生成了Iterator,然后再条用iterator的Init函数。MakeIteratorInternal和Init函数如下:

      std::unique_ptr<IteratorBase> MakeIteratorInternal(
          const string& prefix) const override {
        if (count_ < 0) {
						// make forever iterator .
        } else if (count_ == 0) {
						//make empty Iterator .
        } else {
          return std::unique_ptr<IteratorBase>(new FiniteIterator(
              {this, strings::StrCat(prefix, "::FiniteRepeat")}));
        }
      }
      class FiniteIterator : public DatasetIterator<Dataset> {
       public:
        explicit FiniteIterator(const Params& params)
            : DatasetIterator<Dataset>(params), i_(0) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

代码浅显易懂,就不过多叙述了。如上图所示这样就完成了Iter的链式构造。iter之后由set_iterator函数放入IteratorResource中。并保存在ctx内,由将OneShotInteratorOp作为输入的的IteratorGetNextOp调用。

3 数据流向

其实这部分就是上节封装的IteratorResource中顶层dataset调用getnext方法,获取数据的过程。

ComputeAsync方法同理从ctx中ResourceManager中使用LookupResource接口得到上个op放入的IteratorResource;如下图。

void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
      IteratorResource* iterator;
      OP_REQUIRES_OK_ASYNC(
          ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
  
  		//use thread pool 
  Status s =
                iterator->GetNext(&iter_ctx, &components, &end_of_sequence);

调用该IteratorResource的GetNext的方法,该方法会调用放入首层dataset的Iterator的GetNext函数,如下图:

Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) {
      std::shared_ptr<IteratorBase> captured_iterator(iterator_);
        return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
    }

该函数是基类实现的方法,并且每个派生类没有实现,其中GetNextInternal才是每个派生类实现的方法。

    Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) final {
      tracing::ScopedActivity activity(params_.prefix);
      Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
      return s;
    }

GetNextInternal方法如下,我们用RepeatDataset的Iterator举例,normal Iterator会调用它依赖的iter的GetNext()这样就形成了循环。

 Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
          if (!input_impl_) {
            *end_of_sequence = true;
            return Status::OK();
          }
          while (i_ < dataset()->count_) {
            TF_RETURN_IF_ERROR(
                input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
            if (!*end_of_sequence) {
              return Status::OK();
            }
            ++i_;
            TF_RETURN_IF_ERROR(
                dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
          }
          *end_of_sequence = true;
          input_impl_.reset();
          return Status::OK();
        }

但是root Iterator不同,他是最底层的Iter,没有下层了,所以下层的GetNext函数就有另外的实现方式,我们以最常用的TFRecordDataset举例,它的Iterator的GetNext函数如下:

 do {
            // We are currently processing a file, so try to read the next line.
            if (buffered_input_stream_) {
              string line_contents;
              Status s = buffered_input_stream_->ReadLine(&line_contents);

              if (s.ok()) {
                // Produce the line as output.
                Tensor line_tensor(ctx->allocator({}), DT_STRING, {});
                line_tensor.scalar<string>()() = line_contents;
                out_tensors->emplace_back(std::move(line_tensor));
                *end_of_sequence = false;
                return Status::OK();
              } else if (!errors::IsOutOfRange(s)) {
                // Report non-EOF errors to the caller.
                return s;
              }
              // We have reached the end of the current file, so maybe
              // move on to next file.
              ResetStreamsLocked();
              ++current_file_index_;
            }

            // Iteration ends when there are no more files to process.
            if (current_file_index_ == dataset()->filenames_.size()) {
              *end_of_sequence = true;
              return Status::OK();
            }

            TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
          } while (true);

它依赖了reader_去读取它的内容。

4. 常用的Dataset

4.1 RepeatDataset

有三种Iter,1. ForeverIter,2. EmptyIter, 3. FiniteIter,实现都非常简单,第一种就写while(true),第二种直接返回,第三种写当前批数小于设置批数。我们以第二种为例:

Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
          if (!input_impl_) {
            *end_of_sequence = true;
            return Status::OK();
          }
          while (i_ < dataset()->count_) {
            TF_RETURN_IF_ERROR(
                input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
            if (!*end_of_sequence) {
              return Status::OK();
            }
            ++i_;
            TF_RETURN_IF_ERROR(
                dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
          }
          *end_of_sequence = true;
          input_impl_.reset();
          return Status::OK();
        }

4.2 ShuffleDataset

如图,该Iter自己有buffer,他先通过预先设定的buffersize填充buffer,如下图,需要注意的点是该input_impl是运行时定义的。

while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
            if (ctx->env()->NowMicros() >
                ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
              num_log_entries++;
              LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                        << num_elements_ << " of "
                        << this->dataset()->buffer_size_;
            }
            std::vector<Tensor> input_element;
            bool end_of_input_sequence = false;
            while (this->dataset()->count_ == -1 ||
                   epoch_ < this->dataset()->count_) {
              TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                      &end_of_input_sequence));
              if (!end_of_input_sequence) {
                first_call = false;
                break;
              }

一直会从input里GetNext直到buffer填满,然后随机一个到index,然后将其返回。

int64 offset =
                Random() % (slices_.front()->end - slices_.front()->start);
            int64 index =
                (slices_.front()->start + offset) % this->dataset()->buffer_size_;
            *out_tensors = std::move(buffer_[index]);
            std::swap(
                buffer_[index],
                buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
            slices_.front()->start++;
            num_elements_--;

4.3 MapDataset

比较简单,就是将inputGetNext出来的数据再经过一次变换。如下图:

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {

          std::vector<Tensor> args;
          TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
          if (*end_of_sequence) {
            return Status::OK();
          }

          Status s =
              dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
          if (errors::IsOutOfRange(s)) {

            *end_of_sequence = true;
            return Status::OK();
          } else {
            return s;
          }
        }

4.4 PrefetchDataset

该Dataset,多创建了一个线程填充Buffer,另外一个线程去消费Buffer。GetNextInternal函数如下:

Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          {
            mutex_lock l(mu_);
            TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));

            while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                   auto_tuner_.buffer_limit() != 0) {
              auto_tuner_.RecordEmpty();
              cond_var_.wait(l);
            }

            if (cancelled_) {
              return errors::Cancelled(
                  "PrefetchDatasetOp::Dataset::Iterator::GetNext");
            }

            if (!buffer_.empty()) {
              return Consume(out_tensors, end_of_sequence);
            }

另外一个线程在:

  BufferElement buffer_element;
            buffer_element.status = input_impl_->GetNext(
                ctx, &buffer_element.value, &end_of_sequence);

 

 

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页