TensorFlow1.x最佳实践:Dataset API+Keras Model+TF Train

本文分享TensorFlow 1.x中使用Dataset API进行数据读取,Keras构建模型以及TensorFlow API进行训练的最佳实践。内容包括Dataset API在大数据量读取中的优势,Keras模型的便捷搭建,以及自定义训练过程中的损失函数和学习率调整。适合有一定深度学习基础的读者快速实现和验证想法。
摘要由CSDN通过智能技术生成

前言

更新:想不到当时的想法还是挺好的,文中这种思路基本都出现在TensorFlow 2.0中,确实很爽。Dataset API+Keras Model+Train API搭配简直nice,论文的代码完全用这套重新实现非常舒服。加上面向对象的编码结构,你可以非常容易的自己去写复用性很高的工程,对于科研频繁换网络、试参数、快速验证想法实在是大有帮助,磨刀不误砍柴工,平时多花点时间构思代码结构挺重要的,看到重复三次以上的代码就想想如何消除冗余。

关于深度学习框架,主流的几个TensorFlow、PyTorch以及Keras都有所使用,由于在深度学习领域自己也只算个入门级选手,所以本文只从一个AI新手的角度去尝试分享一些使用框架编程的实践经验。至于标题最佳实践,那也纯粹有些哗众取宠之意,文章对于能够玩转各种框架API的大佬们,也许会贻笑大方。除此以外,本文相当于一个搬运工,并不讲解具体的使用细节,当然会推荐一些已经介绍的很好的文章,看完一定会有所收获。下面正式开始,希望能对大家有所帮助。

在三种框架的使用上,可能最难以上手使用的就是TensorFlow,毕竟在没有Eager Execution时,动态图的特性常常让人对网络调试摸不着头脑。Keras相对来说最容易上手,固定版式的代码,封装性极高,想要扩展对新手来说就有些难了。PyTorch在上手难易程度,扩展性方面都很棒,特别是 torch 张量可以即时看到,便于调试。

总的来说,如果想要从零到一的去写一个深度学习工程代码,我觉得PyTorch会相当的合适。但是我们往往是在别人的工作基础上进行改进,很多开源代码都是基于1.x版本的TensorFlow实现的,或者有些是在Keras基础上实现的,并非PyTorch,我们总不愿意去重新用PyTorch实现一遍,对于学术科研或许有些本末倒置了 😄

我们的目的是快速的实现自己的想法,基于TensorFlow1.x版本的框架实现自己的idea,快速试错。深度学习任务代码的编写,着重解决:

  1. 数据集制作与读取;
  2. 网络模型的搭建,其中包含了自定义网络层等各种复杂操作;
  3. 训练模型的代码,其中也包括自定义损失函数,动态调整学习率等。

针对这三方面,分别有着对应极为适合的方式去实现。

实践方案

a. 数据读取

往往我们的数据集存储在磁盘是直接以 jpg 或者 png 图片的形式,可能几万几十万张不等,标签信息可能也是图片或者存储在 txt 文档中的数据等等。当然,如果将这些零散的数据整合成类似于 npz 或者 TFRecord 这种,也是一样的。对于从硬盘读取大数据量的训练数据,往往都是需要多线程不断加载进行的,内存大小受限,不可能一次性加载全部数据。

TensorFlow在Dataset API之前,大多都是使用 QueueRunner 去搞定这件事。有兴趣可以去研究,这里随便贴一篇文章。老实说,这样的API有些难用,编码复杂性高,容易出错,至少我在平时的编码中确实会遇到数据读取队列出错的问题。相反PyTorch的数据读取方式就显得非常简单,有面向对象编程的那种感觉。TensorFlow在1.3版本之后引入了全新的读取数据API,也就是Dataset API。总的来说,更加的简洁明了,编码难度降低了很多。同样,这里推荐一篇文章,TensorFlow全新的数据读取方式:Dataset API教程。着重可能需要关注磁盘大数据量的读取和对数据的处理。

贴一个我自己写的代码,用于读取磁盘30万张 jpg 图片和对应 txt 标签。

class XxxDataloader:

    def __init__(self, config):
        self.config = config
        self.mode = config.mode

        # 数据路径
        self.img_path = config.img_path
        self.image_names_path = config.image_names_path
        self.gt_file = config.gt_file

        # 图片数据
        self.img_raw_batch = None
        self.img_aug_batch = None
        # 标签数据
        self.gt_batch = None  # ground truth

        # ===========》开始处理 ===========》

        # 读取图像名称和标签,image_names存放的是全部训练数据的
        image_names, gt = self._read_img_and_gt(self.image_names_path, self.gt_file)

        # 创建dataset, dataset中的一个元素是(image_name,, gt)
        dataset = tf.data.Dataset.from_tensor_slices((image_names, pts1_coordinates, gt_h4ps))
        # 通过图片名读取图片数据,并对数据进行处理
        dataset = dataset.map(self._parse_function)
        # 此时dataset中的一个元素是(image_batch, label_batch)
        if config.shuffle:
            dataset = dataset.shuffle(config.buffersize)
        dataset = dataset.batch(config.batch_size).repeat(config.train_epoch)

        # 从头到尾读取一次的iterator
        iterator = dataset.make_one_shot_iterator()

        # 从iterator里取出一个样本
        self.img_raw_batch, self.img_aug_batch, self.gt_batch = iterator.get_next()

    def _parse_function(self, image_name, gt):
    	# 获取图片路径,图片所在路径名称都存在一个txt中
        image_path = tf.string_join([self.img_path, image_name])

        # 读取图片RGB三通道
        image = self._read_image(image_path, [self.img_h, self.img_w], channels=3)
       
        # 数据增强
        random_aug = tf.random_uniform([], 0, 1)
        image_aug = tf.cond(random_aug < self.config.aug_ratio, lambda: self._augment_image(image), lambda: image)

        # 归一化等其他操作
        ......

        return image, image_aug, gt

    def _read_img_and_gt(self, filenames_file, gt_file):
        """
        读取图像名称数据、起始坐标点和ground truth
        :param filenames_file: 保存数据名称文件
        :param gt_file: 标签
        :return: 图的名称、标签
        """
        ..........
        return img_array, gt_array

    def _read_image(self, image_path, out_size, channels=3):
        """
        读取图像,并且resize成指定大小
        :param image_path: 图片路径
        :param out_size: 输出尺寸
        :param channels:
        :return:
        """
        image = tf.image.decode_jpeg(tf.read_file(image_path), channels=channels
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值