TF2.0 - Data Pipelines性能优化

在学习TF2.0官方文档时的总结,跳转https://www.tensorflow.org/guide/data_performance#top_of_page

import tensorflow as tf 
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd
import time 

人工模拟一个数据读取过程:打开文件的时间和读取文件的时间
每个epoch的file打开、读取和训练是串行的,所以整个过程多需时间更长

class ArtificalDataset(tf.data.Dataset):
    def _generator(num_samples):
        time.sleep(0.03)   # 打开文件

        for sample_idex in range(num_samples):
            time.sleep(0.015)   # 读取数据
            yield(sample_idex,)
        
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=tf.dtypes.int64,
            output_shapes=(1,),
            args=(num_samples,)
        )

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            time.sleep(0.015)   # 每个训练
    tf.print('Execution time: ', time.perf_counter() - start_time)

benchmark(ArtificalDataset())

对上述的数据方式进行改进,将数据预取和训练步骤的执行进行重叠
tf.data.Dataset.prefetch()使用了一个后台进程和一个内部的缓存提前从数据集中获取数据,每次获取的数据必须大于等于单个训练步骤所需的数据,可以通过tf.data.experimental.AUTOTUNE可以自动设置参数。

benchmark(
    ArtificalDataset().prefetch(tf.data.experimental.AUTOTUNE)
)

当使用一个管道进行远程读取数据时就会产生I/O瓶颈。tf.data.Dataset.interleave()可以将数据加载并行化。
cycle_length参数用用来控制重叠执行的数量,num_paraller_calls用来控制并行程度

benchmark(
    tf.data.Dataset.range(2).interleave(ArtificalDataset)
)

使用参数num_paraller_calls参数并行加载多个数据集

benchmark(
    tf.data.Dataset.range(2).interleave(
        ArtificalDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)

tf.data.Dataset.map()可以通过自定义函数对每个数据进行预处理,这一过程也可以并行

def mapped_function(s):
    tf.py_function(lambda : time.sleep(0.03), [], ())   # 模仿预处理
    return s

benchmark(
    ArtificalDataset().map(mapped_function)
)

现在对该过程进行并行

benchmark(
    ArtificalDataset().map(
        mapped_function,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)

tf.data.Dataset.cache()可以将数据缓存在本地或者内存

benchmark(
    ArtificalDataset().map(
        mapped_function
    ).cache(),
    5
)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值