深入理解 Tensorflow :如何读训练数据

这里写图片描述以下分析来自 tensorflow slim 库代码精简之后


dataset = dataset_factory.get_dataset(dataset_name, dataset_split_name, dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=num_readers,
          common_queue_capacity=20 * batch_size,
          common_queue_min=10 * batch_size)

    key, data = parallel_reader.parallel_read(
        dataset.data_sources,
        reader_class=dataset.reader,
        num_epochs=num_epochs,
        num_readers=num_readers,
        reader_kwargs=reader_kwargs,
        shuffle=shuffle,
        capacity=common_queue_capacity,
        seed=seed,
        scope=scope)

        data_files = get_data_files(dataset.data_sources)
        # 这里对数据源创建一个 FIFO 队列
        filename_queue = tf_input.string_input_producer(data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed, name='filenames')
            input_tensor = ops.convert_to_tensor(data_files, dtype=dtypes.string)
            if shuffle:
                input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
            # 最多读 num_epochs 次,超过就会抛 OutOfRangeError,当 num_epochs 为 None 时,可以无限次读
            input_tensor = limit_epochs(input_tensor, num_epochs)
            element_shape = input_tensor.shape[1:].merge_with([])
            q = data_flow_ops.FIFOQueue(capacity=32, dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape], shared_name=shared_name, name=name)
            enq = q.enqueue_many([input_tensor])
            queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
            return q

        if shuffle:
            common_queue = data_flow_ops.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=dtypes,
                seed=seed,
                name='common_queue')
        else:
            common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes, name='common_queue')

        reader_kwargs = reader_kwargs or {}
        enqueue_ops = []
        for reader in [reader_class(**reader_kwargs) for _ in range(num_readers)]
          enqueue_ops.append(common_queue.enqueue(reader.read(queue)))

        queue_runner.add_queue_runner(queue_runner.QueueRunner(common_queue, enqueue_ops))
        return common_queue.dequeue(name=name)

    items = dataset.decoder.list_items()
    tensors = dataset.decoder.decode(data, items)
    items_to_tensors[record_key] = key

    return super(DatasetDataProvider, self).__init__(items_to_tensors=items_to_tensors, num_samples=dataset.num_samples)

由上分析可见,创建了两组队列
1. FIFOQueue 队列,从 data_files 读取数据,写入该队列尾部
2. num_readers 个 FIFOQueue 或 RandomShuffleQueue 队列,从 FIFOQueue 队列头读数据

其中 add_queue_runner 将各个 queue_runner 加入 ops.GraphKeys.QUEUE_RUNNERS,
当训练开始的时候, 会调用 start_queue_runners,它会为 enqueue_ops 中的每个
操作启动一个线程。 具体参考 python/training/queue_runner_impl.py

还有一点需要注意的,

  1. 队列的实现是 cpp 来实现的,
  2. queue_runner 是 python 的线程。
  3. TFRecordReader 和 TFExampleDecoder 核心都是 cpp 实现的

备注:关于队列部分和 TFRecordReader,我将开专门的文章分析。

这个实现有什么问题?

  1. FIFOQueue 队列的 capacity 太小只要 32,因此,瓶颈可能在 FIFOQueue 队列
  2. 队列都是本机内的,无法跨主机,而事实上对于一个大型深度学习系统来说,数据一般不可能在同一台机器。跨机器访问是刚需
  3. 当然,如果程序中断,网络中断,必须从头开始,因此可靠性不够

改进,将 FIFOQueue 队列改为一个类似 kafka 的分布式队列即可

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值