tensorflow2.2.0中的tf.data加快数据的处理速度

1.训练速度慢的原因

在我们训练网络的时候,是cpu先处理好数据然后送入网络中gpu进行计算,这样在cpu处理数据时非常的慢,会使得gpu有较长的空载时间。常规的训练方式如下。
在这里插入图片描述

2 加速方法

2.1 方法1 tf.data.prefetch()

这是可以使用tf.data.prefetch()方法,提前从数据集中取出若干数据放到内存中,这样可以使在gpu计算时,cpu通过处理数据,从而提高训练的速度。如下图所示
在这里插入图片描述

#手动设置
dataset = dataset.prefetch(config.batch_size).batch(config.batch_size).repeat(config.epochs)
#tensorflow自动划分
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE).batch(config.batch_size).repeat(config.epochs)

2.2 方法1 data.map()

使用data.map()函数,与data.perfecth()类似,也可以充分利用cpu的多核对数据进行预处理进行并行加速。

dataset = raw_dataset.map(_parse_example,num_parallel_calls=config.nums)

在这里插入图片描述
通过设置 Dataset.map() 的 num_parallel_calls 参数实现数据转换的并行化,上部分是未并行化的图示,下部分是 2 核并行的图示

当然,这里同样可以将 num_parallel_calls 设置为 tf.data.experimental.AUTOTUNE 以让 TensorFlow 自动选择合适的数值。
tensorflow官方给出了关于数据输入流水线处理方法和数据处理性能提升的方法。https://tensorflow.google.cn/guide/data_performance

3 完整代码

数据的格式为图片的字节数据 .dat文件
在这里插入图片描述
这里使用TFRecord处理数据,因为数据为字节型数据,所以使用struct解析数据,并将数据保存为.tfrecords格式。因为我的网络的目标有两个regression和classfication,因为两个数据的维度不一样,所以将两个数据分装成一个tensor时(iterator = tf.compat.v1.data.make_one_shot_iterator(dataset))出现错误,所以又自定义一个方法MakeData()来解析生成器中的数据,并将目标函数组成一个list为[reg,clas]。


size = config.hight * config.width * 3
def write_tfrecord(tfrecord_file,allData,trainORvalida=True):

    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        Data = allData
        print(Data.shape[0])
        for i in range(Data.shape[0]):
            img_filename = Data[i][0]

            reg_filename = Data[i][1]
            label = Data[i][2]
            f = open(img_filename, 'rb')  # 读取数据集图片到内存,image 为一个 Byte 类型的字符串
            image = struct.unpack('f' * size, f.read(4 * size))
            reg = np.loadtxt(reg_filename, delimiter=',')  # 读取数据集图片到内存,image 为一个 Byte 类型的字符串
            feature = {                             # 建立 tf.train.Feature 字典
                'image': tf.train.Feature(float_list=tf.train.FloatList(value=image)),  # 图片是一个 Bytes 对象
                'reg': tf.train.Feature(float_list=tf.train.FloatList(value=reg.flatten())),  # 图片是一个 Bytes 对象
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)]))   # 标签是一个 Int 对象
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Example
            writer.write(example.SerializeToString())   # 将Example序列化并写入 TFRecord 文件

def read_tfrecode(tfrecord_file):

    raw_dataset = tf.data.TFRecordDataset(tfrecord_file)    # 读取 TFRecord 文件
    # print(ra
    feature_description = { # 定义Feature结构,告诉解码器每个Feature的类型是什么
        'image': tf.io.FixedLenFeature([192,336,3], tf.float32),
        'reg': tf.io.FixedLenFeature([1,54],tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def _parse_example(example_string): # 将 TFRecord 文件中的每一个序列化的 tf.train.Example 解码
        feature_dict = tf.io.parse_single_example(example_string, feature_description)
        # feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])    # 解码JPEG图片

        label = tf.one_hot(feature_dict['label'], config.classNum)
        return feature_dict['image'],  feature_dict['reg'], label

    # dataset.prefetch()
    dataset = raw_dataset.map(_parse_example,num_parallel_calls=config.nums)
    # dataset = dataset.shuffle(buffer_size=10000)
    dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    dataset = dataset.prefetch(config.batch_size).batch(config.batch_size).repeat(config.epochs)
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)


    return iterator

def MakeData(dataset):


    while 1:
        # for data in list(dataset.as_numpy_iterator()):
        data = dataset.get_next()
        img = data[0]
        reg = data[1]
        label = data[2]
        yield np.array(img),[np.array(reg),np.array(label)]


def MakeValidation():
    raw_dataset = tf.data.TFRecordDataset(config.validation_tfrecord_file)    # 读取 TFRecord 文件
    feature_description = { # 定义Feature结构,告诉解码器每个Feature的类型是什么
        'image': tf.io.FixedLenFeature([192,336,3], tf.float32),
        'reg': tf.io.FixedLenFeature([1,54],tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def _parse_example(example_string): # 将 TFRecord 文件中的每一个序列化的 tf.train.Example 解码
        feature_dict = tf.io.parse_single_example(example_string, feature_description)
        feature_dict['image'] = tf.image.convert_image_dtype(feature_dict['image'],dtype=tf.uint8)
        # feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])    # 解码JPEG图片
        label = tf.one_hot(feature_dict['label'], config.classNum)
        # return tf.reshape(feature_dict['image'],[config.hight,config.width,3]),feature_dict['reg'],label
        return feature_dict['image'], feature_dict['reg'], label

    dataset = raw_dataset.map(_parse_example)

    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

    validation_x = np.array([data[0] for data in list(dataset.as_numpy_iterator())])
    validation_y_reg = np.array([data[1] for data in list(dataset.as_numpy_iterator())])
    validation_label = np.array([data[2] for data in list(dataset.as_numpy_iterator())])

    return validation_x,validation_y_reg,validation_label
if __name__ == '__main__':
    from DataProcess.ImageProcess import splitDataTrainValidation

    _,validationData = splitDataTrainValidation()
    trainData = np.loadtxt(config.train,delimiter=',',dtype=np.str)
    #
    write_tfrecord(config.train_tfrecord_file,trainData)
    write_tfrecord(config.validation_tfrecord_file, validationData)


    validation_x,validation_label,validation_clas = MakeValidation()
    #

    dataset = read_tfrecode(config.train_tfrecord_file)
    for x,y in MakeData(dataset):
        print(x.shape,y[0].shape,y[1].shape)

4 总结

这是第一次写博客,写的可能不好。使用data.perfectch和TFRecord使网络训练变快了。下次更新使用tf.data.Dataset.from_generator处理数据。

  • 5
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值