tensorflow2之数据管道Dataset

原则

  • 数据量不大,直接入内存计算即可
  • 数据量过大,无法一次性载入内存,需要分批读入:tf.data的API构建数据输入管道

构建

  • numpy

      ds = tf.data.Dataset.from_tensor_slices((['train_x'], ['train_y']))
    
  • pandas:同上df.to_dict('list')

  • generator

    def generator():
        for features, labels in ds:
             yield (features, labels)
    
    ds = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.int32))
    
  • csv

      tf.data.experimental.make_csv_dataset(file_pattern = ['x.csv', 'xx.csv'], batch_size=3, label_name='survived', 
                                            na_value='', num_epochs=1, ignore_errors=True)
    
  • 文本:

    tf.data.TextLineDataset(filenames = ['x.csv', 'xx.csv']).skip(1) # 去掉第一行的header
    
  • 文件路径:

      tf.data.Dataset.list_files('./*/*.jpg')
    
  • tfrecords文件

    • 缺点:复杂,需要对样本构建tf.Example后压缩城字符串写到tfrecords文件,读取后再解析成tf.Example
    • 优点:压缩后文件较小,便于网络传播,加载速度快
管道提升
  • 模型训练耗时的两个部分
    • 数据准备:构建高效的数据管道来提升

      • 使用prefetch方法让数据准备和参数迭代两个过程相互并行

        # 模拟数据准备
        def generator():
            for i in range(10):
                time.sleep(2)
                yield i
        
        # 模拟参数迭代
        def train_step():
            time.sleep(1)
        
        # 一般情况下的串行,耗时:10 * 2 + 10 * 1 = 30s
        ds = tf.data.Dataset.from_generator(generator, output_types=(tf.int32))
        for x in ds:
            train_step()
        
        # prefetch实现数据准备和参数迭代相互并行,耗时:max(10 * 2, 10 * 1) = 20s
        for x in ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE):
            train_step()
        
      • 使用interleave方法可以让数据读取过程多进程执行,并将不同来源数据夹在一起

        ds_files = tf.data.Dataset.list_files("./data/titanic/*.csv")
        # flat_map单进程
        ds = ds_files.flat_map(lambda x:tf.data.TextLineDataset(x).skip(1))
        # interleave多进程
        ds = ds_files.interleave(lambda x:tf.data.TextLineDataset(x).skip(1))
        
      • 使用map时设置num_parallel_calls让数据转换过程多进程执行

        ds = tf.data.Dataset.list_files("./*/*.jpg")
        def load_image(img_path,size = (32,32)):
             label = 1 if tf.strings.regex_full_match(img_path,".*/automobile/.*") else 0 # 文件夹automobile下的label为1,否则为0
             img = tf.io.read_file(img_path)
             img = tf.image.decode_jpeg(img)
             img = tf.image.resize(img,size)
             return(img,label)
        # 单进程
        ds_map = ds.map(load_image)
        for _ in ds_map:
             opera
        # 多进程
        ds_map_parallel = ds.map(load_image, num_parallel_calls = tf.data.experimental.AUTOTUNE)
        for _ in ds_map_parallel :
             opera
        
      • 使用cache方法让数据在第一个epoch后缓存到内存中,仅限于数据集不大的情况

        ds = tf.data.Dataset.from_generator(generator,output_types = (tf.int32)).cache()
        
      • 使用map转换时,先batch,然后采用向量化的转换方法对每个batch进行转换

        ds = tf.data.Dataset.range(100000)
        
        # 先map后batch
        ds_map_batch = ds.map(lambda x: x ** 2).batch(20)
        for x in ds_map_batch :
             opera
             
        # 先batch后map
        ds_batch_map = ds.batch(20).map(lambda x: x ** 2)
        for x in ds_batch_map :
             opera
        
    • 参数迭代:依赖GPU来提升

数据转换

  • map:同Python的map,将转化函数映射到数据集每一个元素
  • flat_map:映射后将多维压平成一维
  • interleave:类似flat_map,但可以将不同来源的数据夹在一起
  • filter:过滤某些元素
  • zip:横向铰合
  • concatenate:纵向铰合
  • reduce:归并
  • batch:构建批次,每次一批。逆操作unbatch
  • padded_batch:构建批次,类似batch,但可以填充到相同的形状
  • window:滑动窗口
  • shuffle:同np.shuffle
  • repeat:重复数据若干次
  • shard:采样,从某个位置开始隔固定距离采样一个元素
  • take:采样,类似top(n), head(n)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值