[Keras学习]fit_generator浅析及完整实例


在做计算机视觉相关项目的时候,我们可以看到基本所有的keras的项目都用到了fit_generator这个函数来进行网络训练。在学习的过程中,发现没有讲解比较详细的实例来更好的来理解,于是结合别人的讲解及例子自己写一下比较详细的代码解析。
本文章代码主要参考:
Keras 系列(六) CNN 分类及fit_generator函数
在其代码的基础上修改了几处bug,并增加了数据验证集,数据增强等操作。
如何使用Keras fit和fit_generator(动手教程)
这是一篇讲解比较详细的文章,但我没有找到完整的项目代码。

为什么要用fit_generator

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。

这些数据集通常不是很具有挑战性,不需要任何数据增强。

但是,真实世界的数据集很少这么简单:

真实世界的数据集通常太大而无法放入内存中
它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力

所以在实际项目中,训练数据会很大,以前简单地使用model.fit将整个训练数据读入内存将不再适用,所以需要改用model.fit_generator分批次读取,并且我们还可以利用其进行数据增强。

fit_generator的参数

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32 #batch size 

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
	validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
	epochs=EPOCHS)

该函数的主要参数有:

  1. generator:生成器函数,输出应该是形为(inputs,target)或者(inputs,targets,sample_weight)的元组,生成器会在数据集上无限循环

  2. steps_per_epoch: 顾名思义,每轮的步数,整数,当生成器返回 stesp_per_epoch次数据时,进入下一轮。

  3. epochs :整数,数据的迭代次数

  4. verbose:日志显示开关。0代表不输出日志,1代表输出进度条记录,2代表每轮输出一行记录

  5. validation_data:验证集数据.

  6. max_queue_size: 整数. 迭代骑最大队列数,默认为10

  7. workers: 最大进程数。在使用多线程时,启动进程最大数量(process-based threading)。未特别指定时,默认为1。如果指定为0,则执行主线程.

  8. use_multiprocessing: 布尔值。True:使用基于过程的线程

顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。

该函数本身是一个Python生成器。

Keras在使用.fit_generator训练模型时的过程:

  1. Keras调用提供给.fit_generator的生成器函数
  2. 生成器函数为.fit_generator函数生成一批大小为BS的数据
  3. .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
  4. 重复该过程直到达到期望的epoch数量
    您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。

为什么我们需要steps_per_epoch?

请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。

由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。

因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。

这些理论刚开始看起来特别的抽象与难以理解,让我们来结合具体例子来看看fit_generator是怎么工作的。

实例解析

这里我们用一个CNN花朵分类来实战。我是在colab上执行的,大家可以用notebook来运行,这样一块块比较方便观察输出与改错。其ipynb源文件我之后会传到我的GitHub上。

1.第一步 下载数据,加载相关库

数据是五种花的图片,每种花有320张图。我们用他们来训练我们的CNN网络。
下面有数据下载链接。

#数据下载地址:http://download.tensorflow.org/example_images/flower_photos.tgz
#加载相关模块
from skimage import io,transform
from pandas import Series, DataFrame 
import glob
import os
import numpy as np
from keras.models import Sequential
from keras.layers.core import Flatten,Dense,Dropout
from keras.layers.convolutional import Convolution2D,MaxPooling2D,ZeroPadding2D
from keras.optimizers import SGD,Adadelta,Adagrad
from keras.utils import np_utils,generic_utils
from keras.layers.advanced_activations import PReLU
from keras.layers.core import Flatten, Dense, Dropout 
from keras.layers.core import Dense, Dropout, Activation, Flatten 
from six.moves import range

2.第二步 设置路径和图片的形状大小

path='/content/gdrive/My Drive/keras/flower_photos' #修改为自己的解压的图片路径
w=182
h=182
c=3

3.第三步 读取所有图片数据

#读取图片
def read_img(path):
    cate=[path+'/'+x for x in os.listdir(path) if os.path.isdir(path+'/'+x)]
    print(cate)
    imgs=[]
    labels=[]
    n=0
    for idx,folder in enumerate(cate):
        for im in glob.glob(folder+'/*.jpg'):
            #print('reading the images:%s'%(im))
            img=io.imread(im)
            #print('before resize:',img.shape)  # 
            img=transform.resize(img,(w,h))
            #print('after:',img.shape)
            imgs.append(img)
            labels.append(idx)
            n=n+1
           
    return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
#调用函数
data,label=read_img(path)
#print('叠加之后的形状:',data.shape) 

4.第四步 打乱样本,转化标签

#打乱顺序,将标签转为二进制独热形式(0和1组成)
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]
from keras.utils.np_utils import to_categorical
labels_5= to_categorical(label,num_classes=5)
print(labels_5)

5.第五步 划分训练集,验证集

直接调用函数自动划分训练集x与验证集y。这里test_size=0.3说明验证集占总数据的30%。

from sklearn.model_selection import train_test_split
x_train, y_test, x_label, y_label = train_test_split(data,labels_5, test_size=0.3, random_state=42)

6.第六步 编写迭代器/生成器(重点)

def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False, aug=None):
    while 1:  # 要无限循环
        assert len(inputs) == len(targets)#判断输入数据长度和label长度是否相同
        if shuffle:
            indices = np.arange(len(inputs))
            np.random.shuffle(indices)
        for start_idx in range( len(inputs) - batch_size ):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batch_size]
                if aug is not None:
                  (inputs[excerpt], targets[excerpt]) = next(aug.flow(inputs[excerpt],targets[excerpt], batch_size=batch_size))
            else:
                excerpt = slice(start_idx, start_idx + batch_size)
                if aug is not None:
                  (inputs[excerpt], targets[excerpt]) = next(aug.flow(inputs[excerpt],targets[excerpt], batch_size=batch_size))
            yield inputs[excerpt], targets[excerpt]#每次产生batchsize个数据

minibatches是我们自己定义的Keras生成器。
他负责读取我们的数据文件并将图像加载到内存中。它为我们的Keras .fit_generator函数生成批量数据。
我们也可以根据自己的需要编写符合自己需求的生成器函数,不过其格式都差不多,而功能都是一样的,不断产生批量数据给fit_generator调用。
这里我们定义了五个参数,但是其实只有前三个是必须的。inputs是我们的图像数据,targets使我们数据的标签。
最后两个是可选功能,shuffle是洗牌,可以打乱顺序,而aug则是fit_generator比fit而言特有的数据增强功能,aug :(默认为None)如果指定了扩充对象,那么我们将在生成图像和标签之前应用它,这个参数我们稍后细讲。

我们来一行行分析:

while 1: # 要无限循环
我们的Keras生成器必须无限循环。每次需要一批新数据时,fit_generator函数将调用我们的minibatches函数。

此外,Keras维护数据的缓存/队列,确保我们正在训练的模型始终具有要训练的数据。Keras不断保持这个队列的满载,所以即使你已经达到要训练的epoch总数,请记住Keras仍在为数据生成器提供数据,将数据保留在队列中。

始终确保您的函数返回数据,否则,Keras将错误地说它无法从您的生成器获取更多的训练数据.

assert len(inputs) == len(targets)#判断输入数据长度和label长度是否相同
这句话检查一下你的训练集数据和训练集标签长度是否一致.

if shuffle:
如果使用功能,产生一个乱序indices供下面使用。

for start_idx in range( len(inputs) - batch_size ):
从第一个第一个数据开始遍历,直至第len(inputs) - batch_size 个。因为我们每次取batch_size个数据,所以最后一次遍历只能到len(inputs) - batch_size个。

excerpt = slice(start_idx, start_idx + batch_size)
产生一个长度为batch_size的索引,因为我们的start_idx一直再增加,所以我们每次循环的索引也不一样。因此我们每次产生一组不同的数据。

if aug is not None:
(inputs[excerpt], targets[excerpt]) = next(aug.flow(inputs[excerpt],targets[excerpt], batch_size=batch_size))

这个功能是用来数据增强的。其具体函数我们下面细讲。

yield inputs[excerpt], targets[excerpt]#每次产生batchsize个数据
我们的生成器根据请求“生成”图像数组和调用函数标签列表。这里相当于用之前产生好的索引生成一组batch_size个的图片数据。
如果您不熟悉yield关键字,它用作Python Generator函数,作为一种方便的快捷方式,而不是构建具有较少内存消耗的迭代器类。如果不细究的话,在写我们函数时可以理解为return。

7.第七步 构建数据增强函数

如果你不需要数据增强功能的话,可以不加这段话。

from keras.preprocessing.image import ImageDataGenerator
# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

导入ImageDataGenerator,它包含数据增强和图像生成器功能。

aug,这是一个Keras ImageDataGenerator对象,用于图像的数据增强,随机平移,旋转,调整大小等。

执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。

应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。

根据提供给ImageDataGenerator的参数随机调整每批新数据。

注意我们只应该对训练集应用数据增强。

8.第八步 网络构建

CNN网络构建不是我们的重点,在这里不细讲。

model = Sequential() #第一个卷积层,4个卷积核,每个卷积核大小5*5。 
#激活函数用tanh #你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5)) 
model.add(Convolution2D(4, 5, 5,input_shape=(w, h,3))) 
model.add(Activation('relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) #第二个卷积层,8个卷积核,每个卷积核大小3*3。
#激活函数用tanh #采用maxpooling,poolsize为(2,2) 
model.add(Convolution2D(8, 3, 3)) 
model.add(Activation('relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
#第三个卷积层,16个卷积核,每个卷积核大小3*3 #激活函数用tanh 
#采用maxpooling,poolsize为(2,2) 
model.add(Convolution2D(16, 3, 3)) 
model.add(Activation('relu')) 
model.add(MaxPooling2D(pool_size=(2, 2)))
#全连接层,先将前一层输出的二维特征图flatten为一维的。
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
#多分类  
model.add(Dense(5)) # 共有5个类别
model.add(Activation('softmax'))
#print(model.summary())
model.compile(loss='categorical_crossentropy',optimizer='adam')#使用分类交叉熵(categorical_crossentropy),因为我们有超过2个类别,否则将使用二进制交叉熵(binary crossentropy )。

9.第九步 模型训练

#model.fit(data,labels_5,epochs=6,batch_size=2,verbose=2)#旧方法不再适用
H=model.fit_generator(minibatches(x_train,x_label,batch_size=6,shuffle=False,aug=aug),
                            steps_per_epoch=len(x_train)//6,
                            validation_data=minibatches(y_test,y_label,batch_size=6,shuffle=False,aug=None),
                            validation_steps=len(y_test)//6,
                            epochs=6)
#model.train_on_batch(minibatches(data, labels_5, batch_size=6, shuffle=False))

我们按照fit_generator的定义来应用,训练集与验证集都需要我们的生成器函数。steps_per_epoch从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
注意我们只对训练集应用数据增强。

10.第十步 评估训练结果

最后一步,使用训练历史字典H和matplotlib来生成图:

import matplotlib.pyplot as plt
N = 6 # N=epochs 
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig("plot.png")

这里因为我的epochs 不是很大,所以可能loss图像下降不算很明显,大家可以增大一下。

  • 29
    点赞
  • 97
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值