tf.data

一、tf.data简介

从读取数据到数据传入加速设备(GPU或TPU)的流程被称为输入管道(Input Pipeline)。TensorFlow提供tf.data API,可以帮助使用者打造灵活有效的输入管道,轻松处理大量数据、不同数据格式及复杂的转换。

输入管道可以分为以下3个步骤:
1、提取(Extraction)

从存储的地方(可以是SSD、HDD或远程存储位置)读取数据。

2、转换(Transformation)

使用CPU进行数据预处理,例如对图像进行翻转、裁剪、缩放和正则化等。

3、载入(Loading)

将转换后的数据加载到机器学习模型的加速设备。

上面这3个步骤主要是设备读取数据和CPU预处理在消耗时间,如果没有妥善地分工处理,就会造成当CPU在准备数据时,GPU在等待训练数据的产生(GPU处于空闲状态);反之,当GPU在训练时,CPU则处于空闲状态,如图2-16所示。如此,训练时间就会增加很多。

TensorFlow提供tf.data API,通过使用tf.data.Dataset.prefetch,行指令就可以让生成数据与训练数据同时进行,进而提升训练效率,如图2-17所示。

倘若输入管道的执行时间远比训练时间久,将发生如图2-18所示的情况,造成GPU或TPU加速器无法发挥全部的运算力,通常这种情况可能是读取文件太大或数据预处理太久造成的。

上述问题可以使用CPU多线程来解决,只需在调用map方法时加入 num parallel_calls 设置,即可启用并行处理数据的功能。通常num parallel_calls 会设置成计算机的核心数,图2-19所示为改善后的工作情况。

 二、基本操作

import tensorflow as tf
import os

def _parse_function(filename, label):
    image_string = tf.compat.v1.read_file(filename)
    # 将图像使用JPEG的格式解码从而得到图像对应的三维矩阵。
    image_decoded = tf.image.decode_jpeg(image_string,channels=3)
    # 通过tf.image.resize_images函数调整图像的大小。
    image_resized = tf.compat.v1.image.resize_images(image_decoded, [224, 224])
    return image_resized, label

file_path = r'E:\修图作品\2022-6-16'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)

dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function) #map对数据集应用自定义函数
dataset = dataset.batch(40) #设置每一批读取的数据量
dataset = dataset.repeat(2) #设置可以重复读取dataset n次

iterator = iter(dataset)
while 1:
    try:
        image, _ = next(iterator)
        print(image.shape)
    except StopIteration:   # python内置的迭代器越界错误类型
        print("iterator done")
        break;
'''
len(data) = 93
40+40+13=93
'''

 

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值