Keras花卉分类全流程(预处理+训练+预测)

本文的代码包括以下内容的示例:

1.用一个类封装自己的模型和训练、预测等过程
2.使用图片生成器(ImageDataGenerator)进行数据预处理,这一功能是Keras很 方面的地方,省去了自己进行数据处理的过程
3.在使用图片生成器的基础上进行训练的过程(fit_generator用法)
4.如何使用图片生成器进行预测和精度验证,这一部分包括predict_generator 的用法,以及相关的标签提取的过程,其中的一些细节。这一部分内容,在官方文档亦是没有提及。(我是没找到)

提醒

对于keras初学者来说,可能难以看懂,但是只要坚持看完代码,弄清楚数据生成器的使用,对使用keras进行分类的问题将是巨大帮助。这个东西我也是写加改弄了三四天,感觉帮助很大。

模型

本文采用的是简化版的VGG的A型的网络

优化方法:

本文采用的是Adamdelta

数据集

是Google的一个花卉数据集
共分为五类:
daisy,dandelion,roses,sunflower,tulips

我已手动分好了数据
当然比例有点问题
不过就是个小实验嘛

链接:https://pan.baidu.com/s/1ktu-6GOWnSYjuzHyxFeL7Q
提取码:bask

在给出代码之前,我想总结的几点经验:

1.模型难以收敛可能与图片尺寸有关系,较小的图片比较容易收敛,模型也相对容易训练。较大的图片使得模型的参数呈几何倍数增加,训练难度加大。

2.我上一篇博客写的是关于数据预处理的,是将原始图片分成了训练集、验证集、和测试集,存成了numpy数组的形式,进行保存。这样子不甚方便。
keras数据生成器的.flow_from_directroy()方法,直接生成各式数据。不需要进行人工的预处理。
数据需要组织成以下形式:
主文件夹下属train、validation、test三个文件夹,每个文件夹内下属多个类别名称文件夹,每个类别名称文件夹下下属该类别的图片。

3.全连接层与最后一个卷积层之间相连的时候,所需的显存极大,参数极多,容易造成资源枯竭的错误。尤其是当图片的尺寸较大的时候,更容易发生错误。
我的配置:
CPU:i7-8750H
内存:8GB
显卡:GTX1060 6GB显存
这样的配置,在一个7x7x512*4096的大张量下直接不行。最后进行了调整,适应了我的电脑配置。
通常出现: OOM when allocating tensor 就是这个地方有问题

4.loss不下降很有可能是你训练的轮次不够多。
因为从mnist的一下子训练完,到现在你训练一个比较大的网络,所需的时间是不一样的。需要转过弯来。尤其是图片尺寸和网络深度都相比较较大的时候,训练绝不是能够快速的就拟合的,而是需要一个缓慢的变化过程。

5.loss为nan有可能是你的数据没有成功的喂入网络中。尤其是一上来loss和acc就为nan的情况。

6.遇到问题,百度不到的,可以试试Google,找到答案的几率更大。

代码部分:

下面是代码:

#
#coding:utf-8
#
from __future__ import  print_function
import keras
from keras.models import Sequential
import numpy as np
from keras.layers import Dense,Conv2D,Dropout,BatchNormalization
from keras.layers import Activation,MaxPooling2D,Flatten
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from PIL import Image
from keras.utils.vis_utils import plot_model
import keras.backend as K

'''
Use 11 weight layers vursion to reduce
parameters to train fast
The dataset is small,only 3000 pictures
'''


class SampleVGGForFlower:
    '''
    This class is for flower dataset of Google
    
    '''
    def __init__(self,train_path,input_shape=(100,100,3),imagesize=(100,100),validation_path=None,test_path=None,train=True,show=True,plot=False):
    '''
	模型的传入图片的维度和尺寸都可以另行改变
	因为不同的尺寸训练的难度也很不同
	show:是否展示图片的训练过程的精度和损失值变化
	plot:是否保存权重文件
	train:是进行训练还是不进行训练只进行预测
	'''
        self.num_classes = 5
        self.batch_size = 16
        self.num_epoch = 200
        self.input_shape = input_shape
        self.imagesize = imagesize
        self.learningrate = 1
        self.trainpath = train_path
        self.validationpath = validation_path
        self.testpath = test_path
        self.show = show
        self.plot = plot
        
        self.model = self.build_model(plot)
        
        if train:
            self.model = self.train(self.model,show)
        else:
            self.model.load_weights('flowervgg.h5')
    
    def build_model(self,plot):
        '''
        Build the model with 11 layers.
        This is a sample vursion of VGG net
        Because of my device,some parameters are
        changed.
        我设备不行,必须去掉4096长的全连接层
        采取贯序模型
        '''
        model = Sequential()
        model.add(Conv2D(64,(3,3),padding='same',input_shape=self.input_shape))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        
        model.add(MaxPooling2D())
        
        model.add(Conv2D(128,(3,3),padding='same'))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D())

        model.add(Conv2D(256,(3,3),padding='same'))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.3))

        model.add(Conv2D(256,(3,3),padding='same'))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D())

        model.add(Conv2D(512,(3,3),padding='same'))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.3))

        model.add(Conv2D(512,(3,3),padding='same'))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D())
        
        model.add(Flatten())

        model.add(Dense(1000))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.3))

        model.add(Dense(self.num_classes))
        model.add(Activation('softmax'))
        #有人说keras里softmax和交叉熵不能一起用
        #无稽之谈
        
        if plot:
            plot_model(model,'model.png',show_shapes=True,show_layer_names=True)

        return model
    
    def train(self,model,show=True):
        '''
        In this function,model will be trained.
        And if show was set,a image of the model
        will be stored at the directory.
        训练网络
        '''
        train_datagen = ImageDataGenerator(
            rotation_range=15,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            vertical_flip=False
        )
        validation_datagen = ImageDataGenerator()
        train_generator = train_datagen.flow_from_directory(
            directory=self.trainpath,
            target_size=self.imagesize,
            color_mode='rgb',
            classes=['daisy','dandelion','roses','sunflowers','tulips'],
            class_mode='categorical',
            batch_size=self.batch_size,
            shuffle=True
        )
        validation_generator = validation_datagen.flow_from_directory(
            directory=self.validationpath,
            target_size=self.imagesize,
            classes=['daisy','dandelion','roses','sunflowers','tulips'],
            color_mode='rgb',
            class_mode='categorical',
            batch_size=self.batch_size,
            shuffle=True
        )
        opt = keras.optimizers.adadelta(lr = self.learningrate)
        model.compile(optimizer=opt,loss='categorical_crossentropy',metrics=['accuracy'])
        #设置早停功能,二十五轮没有下降就自动停止,调整参数
        earlystop = keras.callbacks.EarlyStopping(monitor='loss',patience=25,mode='auto')

        history = model.fit_generator(
            generator=train_generator,
            steps_per_epoch=90,#this parameter depend on dataset
            epochs=self.num_epoch,
            callbacks=[earlystop],
            validation_data=validation_generator,
            validation_steps=4 #this parameter depend on my dataset
        )

        model.save_weights('flowervgg.h5')#保存模型
		
		'''
		This is the way to show the accuracy or loss
		in the training.
		这里是一个绘制模型训练变化的示例
		'''
        if show:
            plt.plot(history.history['acc'])
            plt.plot(history.history['val_acc'])
            plt.title('model acc')
            plt.ylabel('acc')
            plt.xlabel('epoch')
            plt.legend(['train','validation'])
            plt.show()
        else:
            pass

        return model
    def predict(self):
        '''
        In this funtion,a path that testdata in should be provide,
        and the function could give a accuracy back.
        模型的预测的方法
        '''
        test_datagen = ImageDataGenerator()
        test_generator = test_datagen.flow_from_directory(
            directory=self.testpath,
            target_size=self.imagesize,
            color_mode='rgb',
            classes=['daisy','dandelion','roses','sunflowers','tulips'],
            class_mode='categorical',
            batch_size=1,
            shuffle=False#you must let shuffle False
        )
        
        #获取生成器生成的文件夹的全部文件名,计算数量
        filenames = test_generator.filenames
        nb_samples = len(filenames)
		
		#这里的step必须是全部文件的数量,不然没法算精度了
		#返回预测值的一维数组
        predict = self.model.predict_generator(
            generator=test_generator,
            steps=nb_samples
        )

        #classes属性提供了文件夹内全部文件的类别
        #是一个一维数组
        true_labels = (test_generator.classes)
        pre_labels = np.argmax(predict,axis=-1)
        correct_label = np.equal(true_labels,pre_labels)
        accuracy = np.mean(correct_label)

        print('Accuracy is: ',accuracy)


if __name__ == '__main__':
    
    #从这里传入参数
    model = SampleVGGForFlower(
        train_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\train',
        validation_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\validation',
        test_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\test',
        train = False
        )
    #你也可以不预测,把这行注释掉即可
    #尤其是你没有传入testpath的情况下
    model.predict()

训练集准确率能够达到90%以上,测试集比较小,准确率在75%左右,估计是过拟合了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值