【TensoFlow学习笔记】进阶篇(一)— —如何制作自己的图片数据集(TFRecords文件生成训练集和测试集)

在实际项目中,我们往往需要对特定的数据进行分类,那么首先就需要根据需求制作数据集了。接下来我将以自己之前做的一个手势识别分类项目为例子,详细讲解制作图片数据集的具体操作过程。



1. 数据预处理

1.1 数据准备

在项目中,需要进行12种手势的分类。那么首先需要收集每一种类的图片(10张以上)到每个类别的文件夹中,文件夹以手势类别命名,图片不用命名。
在这里插入图片描述

1.2 数据增强

如果我们只用上面12个文件夹里面的120张图片数据,是无法训练出模型的,会使得模型过拟合,因此只能祭出 data augmentation(数据增强)神器了,通过旋转,平移,拉伸 等操作每张图片生成150张,这样图片就变成了18000张。下面是 data augmentation 的代码:
在深度学习中,我们经常需要用到一些技巧(比如将图片进行旋转,翻转等)来进行data augmentation, 来减少过拟合。 这里,我们主要用到的是深度学习框架keras中的ImageDataGenerator进行data augmentation。

datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        cval=0,
        channel_shift_range=0,
        horizontal_flip=False,
        vertical_flip=False,
        rescale=None)

参数

  • rotation_range:整数,数据提升时图片随机转动的角度
  • width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度`
  • height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度 rescale:
  • 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)
  • shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
  • zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] =
    [1 - zoom_range, 1+zoom_range]
  • fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理
  • cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值
  • channel_shift_range: Float. Range for random channel shifts.
  • horizontal_flip:布尔值,进行随机水平翻转
  • vertical_flip:布尔值,进行随机竖直翻转 rescale: 重放缩因子,默认为None.

如果为None或0则不进行放缩,否则会将该数值乘到数据上.

2. TensorFlow读取数据的三种方式

在讲述在TensorFlow上的数据读取方式之前,有必要了解一下TensorFlow的系统架构,如下图所示:
在这里插入图片描述
TensorFlow的系统架构分为两个部分:
① 前端系统:提供编程模型,负责构造计算图;
② 后端系统:提供运行时环境,负责执行计算图。

在处理数据的过程当中,由于现在的硬件性能的极大提升,数值计算过程可以通过加强硬件的方式来改善,因此数据读取(即IO)往往会成为系统运行性能的瓶颈。在TensorFlow框架中提供了三种数据读取方式:

  • Preloaded data: 预加载数据
  • Feeding: placeholder, feed_dict由占位符代替数据,运行时填入数据
  • Reading from file: 从文件中直接读取

以上三种读取方式各有自己的特点,在了解这些特点或区别之前,需要知道TensorFlow是如何进行工作的。

TF的核心是用C++写的,这样的好处是运行快,缺点是调用不灵活。而Python恰好相反,所以结合两种语言的优势。涉及计算的核心算子和运行框架是用C++写的,并提供API给Python。Python调用这些API,设计训练模型(Graph),再将设计好的Graph给后端去执行。简而言之,Python的角色是Design,C++是Run。

1.1 Preload data: constant 预加载数据

特点:数据直接嵌入graph, 由graph传入session中运行

import tensorflow as tf

#设计graph
x = tf.constant([1,2,3], name='x')
y = tf.constant([2,3,4], name='y')
z = tf.add(x,y, name='z')

#打开一个session,计算z
with tf.Session() as sess:
    print(sess.run(z))


#运行结果如下:
#[3 5 7]

在设计Graph的时候,x和y就被定义成了两个有值的列表,在计算z的时候直接取x和y的值。

1.2 Feeding: placeholder, feed_dict

特点:由占位符代替数据,运行时填入数据

import tensorflow as tf

#设计graph,用占位符代替
x = tf.placeholder(tf.int16)
y = tf.placeholder(tf.int16)
z = tf.add(x,y, name='z')

#打开一个session
with tf.Session() as sess:
    #创建数据
    xs = [1,2,3]
    ys = [2,3,4]
    #运行session,用feed_dict来将创建的数据传递进占位符
    print(sess.run(z, feed_dict={
   x: xs, y: ys}))
#运行结果如下:
#[3 5 7]

1.3 Reading From File:直接从文件中读取

前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。

我们可以使用QueueRunner和Coordinator来实现bin文件,以及csv文件、TFRecord格式文件的读取,不过这里我们采用隐式创建线程的方法。在讲解具体代码之前,我们需要先来讲解关于TensorFlow中的队列机制和线程。

3. 队列和线程

直接从文件中读取数据的方式,需要设计成队列(Queue)的方式才能较好的解决IO瓶颈的问题,同时需要使用多线程来提高图片的批获取效率。
TensorFlow提供了多线程队列存取机制,主要涉及三个概念:Queue、QueueRunner及Coordinator.

3.1 队列(Queue)

队列是常用的数据结构之一,TensorFlow在各个设备(CPU、GPU、磁盘等)之间传递数据时使用了队列。例如,在CPU与GPU之间传递数据是非常缓慢的,为了避免数据传递带来的耗时瓶颈问题,采用异步的方式,CPU不断往队列传入数据,GPU不断从队列中读取数据。

在这里插入图片描述
在上图中,首先由一个单线程把文件名堆入队列,两个Reader同时从队列中取文件名并读取数据,Decoder将读出的数据解码后堆入样本队列,最后单个或批量取出样本(图中没有展示样本出列)。我们这里通过三段代码逐步实现上图的数据流,这里我们不使用随机,让结果更清晰。

  • 队列数据读取机制:
    tf.train.string_input_producer()
    tf.train.start_queue_runners()

  • 文件队列,通过tf.train.string_input_producer()函数来创建,文件名队列不包含文件的具体内容,只是在队列中记录所有的文件名,所以可以在这个函数中对文件设置多个epoch,并对其进行shuffle。这个函数只是创建一个文件队列,并指定入队的操作由几个线程同时完成。真正的读取文件名内容是从执行了tf.train.start_queue_runners()开始的,start_queue_runners返回一个op,一旦执行这个op,文件名队列就开始被填充了。

  • 内存队列,这个队列不需要用户手动创建,有了文件名队列后,start_queue_runners之后,Tensorflow会自己维护内存队列并保证用户时时有数据可读。

  • 详细内容请看这篇文章

3.2线程(Coordinator)

Coordinator用于管理线程,如管理线程同步等操作。

#创建一个协调器,管理线程
coord = tf.train.Coordinator()  
#启动QueueRunner, 此时文件名才开始进队。
threads=tf.train.start_queue_runners(sess=sess,coord=coord) 
.....
#关闭线程协调器
coord.request_stop()
coord.join(threads)

4. 异常处理

通过queue runners启动的线程不仅仅只处理推送样本到队列。他们还捕捉和处理由队列产生的异常,包括OutOfRangeError异常,这个异常是用于报告队列被关闭。 使用Coordinator对象的训练程序在主循环中必须同时捕捉和报告异常。 下面是对上面训练循环的改进版本。

try:
    for step in xrange(1000000):
        if coord.should_stop():
            break
        sess.run(train_op)
except Exception, e:
   # Report exceptions to the coordinator.
   coord.request_stop(e)

# Terminate as usual.  It is innocuous to request stop twice.
coord.request_stop()
coord.join(threads)

5. 生成和读取TFRecords文件

那么接下来就是要将图片数据生成文件格式了,我们这里采用的是TFRecord格式。

  • TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

  • TFRecords是一种二进制文件,可先将图片和标签制作成该格式的文件。使用TFRecords进行数据读取,会提高内存利用率。

  • 用 tf.train.Example的协议存储训练数据。训练数据的特征用键值对的形式表示。如:‘img_raw’:值 ‘label’:值,值是Byteslist/FloatList/int64List

  • 用SerializeToString()把数据序列化成字符串存储。

5.1 生成TFRecords文件

writer = tf.python_io.TFRecordWriter(tfRecordName)#新建一个writer

for 循环遍历每张图和标签:
   example = tf.train.Example(features=tf.train.Features(feature={
   
       'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
       'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
   }))#把每张图片和标签封装到example中,feature为字典形式
   writer.write(example.SerializeToString())#把example进行序列化
writer.close()

5.2 读取TFRecords文件

filename_queue = tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader()#新建一个reader
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
            features={
   
                'label': tf.FixedLenFeature([n_class], tf.int64),
                'img_raw': tf.FixedLenFeature([], tf.string)
            })#解序列化
img = tf.decode_raw(features['img_raw'], tf.uint8)#恢复img_raw到img
img.set_shape([img_height*img_width])#把img的形状变成一行784列
img = tf.cast(img, tf.float32) * (1. / 255)#把img的每个元素变成0-1之间的浮点数
label = tf.cast(features['label'], tf.float32)#把label的每个元素变成浮点数

完整代码

  • 数据增强(ImageDataGenerator.py)

from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
import os
import time

datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.15,
    height_shift_range=0.15,
    zoom_range=0.15,
    shear_range=0.2,
    horizontal_flip=True,
	fill_mode='nearest')

print("start.....: " + str((time.strftime('%Y-%m-%d %H:%M:%S'))))


dirs = os.listdir("D:/360MoveData/Users/ASUS/Desktop/gesture/音量减")
for filename in dirs:
    img = load_img("D:/360MoveData/Users/ASUS/Desktop/gesture/音量减/{}".format(filename))
    x = img_to_array(img)
    # print(x.shape)
    x = x.reshape((1,) + x.shape) #datagen.flow要求rank为4
    # print(x.shape)
    datagen
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值