前言
更新:想不到当时的想法还是挺好的,文中这种思路基本都出现在TensorFlow 2.0中,确实很爽。Dataset API+Keras Model+Train API搭配简直nice,论文的代码完全用这套重新实现非常舒服。加上面向对象的编码结构,你可以非常容易的自己去写复用性很高的工程,对于科研频繁换网络、试参数、快速验证想法实在是大有帮助,磨刀不误砍柴工,平时多花点时间构思代码结构挺重要的,看到重复三次以上的代码就想想如何消除冗余。
关于深度学习框架,主流的几个TensorFlow、PyTorch以及Keras都有所使用,由于在深度学习领域自己也只算个入门级选手,所以本文只从一个AI新手的角度去尝试分享一些使用框架编程的实践经验。至于标题最佳实践,那也纯粹有些哗众取宠之意,文章对于能够玩转各种框架API的大佬们,也许会贻笑大方。除此以外,本文相当于一个搬运工,并不讲解具体的使用细节,当然会推荐一些已经介绍的很好的文章,看完一定会有所收获。下面正式开始,希望能对大家有所帮助。
在三种框架的使用上,可能最难以上手使用的就是TensorFlow,毕竟在没有Eager Execution时,动态图的特性常常让人对网络调试摸不着头脑。Keras相对来说最容易上手,固定版式的代码,封装性极高,想要扩展对新手来说就有些难了。PyTorch在上手难易程度,扩展性方面都很棒,特别是 torch
张量可以即时看到,便于调试。
总的来说,如果想要从零到一的去写一个深度学习工程代码,我觉得PyTorch会相当的合适。但是我们往往是在别人的工作基础上进行改进,很多开源代码都是基于1.x版本的TensorFlow实现的,或者有些是在Keras基础上实现的,并非PyTorch,我们总不愿意去重新用PyTorch实现一遍,对于学术科研或许有些本末倒置了 😄
我们的目的是快速的实现自己的想法,基于TensorFlow1.x版本的框架实现自己的idea,快速试错。深度学习任务代码的编写,着重解决:
- 数据集制作与读取;
- 网络模型的搭建,其中包含了自定义网络层等各种复杂操作;
- 训练模型的代码,其中也包括自定义损失函数,动态调整学习率等。
针对这三方面,分别有着对应极为适合的方式去实现。
实践方案
a. 数据读取
往往我们的数据集存储在磁盘是直接以 jpg
或者 png
图片的形式,可能几万几十万张不等,标签信息可能也是图片或者存储在 txt
文档中的数据等等。当然,如果将这些零散的数据整合成类似于 npz
或者 TFRecord
这种,也是一样的。对于从硬盘读取大数据量的训练数据,往往都是需要多线程不断加载进行的,内存大小受限,不可能一次性加载全部数据。
TensorFlow在Dataset API之前,大多都是使用 QueueRunner
去搞定这件事。有兴趣可以去研究,这里随便贴一篇文章。老实说,这样的API有些难用,编码复杂性高,容易出错,至少我在平时的编码中确实会遇到数据读取队列出错的问题。相反PyTorch的数据读取方式就显得非常简单,有面向对象编程的那种感觉。TensorFlow在1.3版本之后引入了全新的读取数据API,也就是Dataset API。总的来说,更加的简洁明了,编码难度降低了很多。同样,这里推荐一篇文章,TensorFlow全新的数据读取方式:Dataset API教程。着重可能需要关注磁盘大数据量的读取和对数据的处理。
贴一个我自己写的代码,用于读取磁盘30万张 jpg
图片和对应 txt
标签。
class XxxDataloader:
def __init__(self, config):
self.config = config
self.mode = config.mode
# 数据路径
self.img_path = config.img_path
self.image_names_path = config.image_names_path
self.gt_file = config.gt_file
# 图片数据
self.img_raw_batch = None
self.img_aug_batch = None
# 标签数据
self.gt_batch = None # ground truth
# ===========》开始处理 ===========》
# 读取图像名称和标签,image_names存放的是全部训练数据的
image_names, gt = self._read_img_and_gt(self.image_names_path, self.gt_file)
# 创建dataset, dataset中的一个元素是(image_name,, gt)
dataset = tf.data.Dataset.from_tensor_slices((image_names, pts1_coordinates, gt_h4ps))
# 通过图片名读取图片数据,并对数据进行处理
dataset = dataset.map(self._parse_function)
# 此时dataset中的一个元素是(image_batch, label_batch)
if config.shuffle:
dataset = dataset.shuffle(config.buffersize)
dataset = dataset.batch(config.batch_size).repeat(config.train_epoch)
# 从头到尾读取一次的iterator
iterator = dataset.make_one_shot_iterator()
# 从iterator里取出一个样本
self.img_raw_batch, self.img_aug_batch, self.gt_batch = iterator.get_next()
def _parse_function(self, image_name, gt):
# 获取图片路径,图片所在路径名称都存在一个txt中
image_path = tf.string_join([self.img_path, image_name])
# 读取图片RGB三通道
image = self._read_image(image_path, [self.img_h, self.img_w], channels=3)
# 数据增强
random_aug = tf.random_uniform([], 0, 1)
image_aug = tf.cond(random_aug < self.config.aug_ratio, lambda: self._augment_image(image), lambda: image)
# 归一化等其他操作
......
return image, image_aug, gt
def _read_img_and_gt(self, filenames_file, gt_file):
"""
读取图像名称数据、起始坐标点和ground truth
:param filenames_file: 保存数据名称文件
:param gt_file: 标签
:return: 图的名称、标签
"""
..........
return img_array, gt_array
def _read_image(self, image_path, out_size, channels=3):
"""
读取图像,并且resize成指定大小
:param image_path: 图片路径
:param out_size: 输出尺寸
:param channels:
:return:
"""
image = tf.image.decode_jpeg(tf.read_file(image_path), channels=channels