keras-fit_generator

The official documentation for Keras's `fit_generator` is a bit of a trap: the demo it gives doesn't seem to do much, and training a large dataset on your own easily blows up memory.

Here is the demo from the Keras docs, see for yourself:

https://keras-cn.readthedocs.io/en/latest/models/model/

Searching online turned up very little material, and none of it really helped, so I had no choice but to implement the generator function myself.

First, why go to all the trouble of using fit_generator instead of just calling fit?

A colleague of mine wrote a blog post with a short introduction; for reference:

https://blog.csdn.net/jiangpeng59/article/details/79515680

In short, fit_generator solves the problem of a training set too large to load into memory at once: each batch of data is fetched from disk on the fly.
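The idea can be sketched with a generic Python generator that yields one batch at a time, so only `batch_size` samples ever live in memory. This is a minimal sketch with random arrays standing in for images read from disk; the names here are illustrative, not from the script below:

```python
import numpy as np

def batch_generator(paths, batch_size):
    """Yield (x, y) batches forever; only one batch is held in memory at a time."""
    while True:  # fit_generator expects an infinite generator
        batch_x, batch_y = [], []
        for p in paths:
            # In a real script each path would be read from disk (e.g. cv2.imread);
            # a random 4x4 RGB array stands in for the image here.
            batch_x.append(np.random.rand(4, 4, 3))
            batch_y.append(0)
            if len(batch_x) == batch_size:
                yield np.array(batch_x), np.array(batch_y)
                batch_x, batch_y = [], []

gen = batch_generator(["img%d.jpg" % i for i in range(10)], batch_size=4)
x, y = next(gen)
print(x.shape)  # (4, 4, 4, 3): batch of 4 images, each 4x4x3
```

Because the `while True` loop never returns, Keras can keep pulling batches for as many epochs as it wants.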

(Later I decided Keras wasn't flexible enough and writing this generator was too much hassle, so I switched straight to PyTorch. Call it capricious.)

 

Back to the point: here is the overall approach to implementing the generator, with the complete code attached.

1. Write all the training image paths into a file (you could also keep them in a list; I haven't tried that yet).

2. Read the file one line at a time, load the image from that path, and generate the label. Note that the label must be one-hot encoded, otherwise you get the error:

ValueError: Error when checking target: expected activation_6 to have shape (2,) but got array with shape (1,)

3. Set batch_size; after reading batch_size images, yield the images and labels to the model.
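The one-hot requirement in step 2 can be checked without Keras: for a single integer label, `to_categorical(label, 2)` produces a length-2 vector, which is what a 2-unit output layer expects. A NumPy sketch of the same encoding:

```python
import numpy as np

def one_hot(label, num_classes):
    """NumPy equivalent of to_categorical for a single integer label."""
    vec = np.zeros(num_classes)
    vec[label] = 1.0
    return vec

print(one_hot(1, 2))  # [0. 1.] -- shape (2,), matching the expected target shape
print(one_hot(0, 2))  # [1. 0.]
```

Feeding the raw integer label instead gives a target of shape (1,), which is exactly the shape mismatch the ValueError above complains about.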

Experimental result: training with fit() immediately failed with an out-of-memory error and could not train at all.

With fit_generator() it trains normally:

Epoch 2/10
17/20 [========================>.....] - ETA: 1:34 - loss: 0.6823 - acc: 0.2886^CTraceback (most recent call last):

 

My demo below is for reference only; if you spot problems, please point them out.

'''
2018-8-21
fit_generator
heq
'''
import os

import numpy as np
import cv2
from keras.optimizers import SGD
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.models import Sequential
from keras.utils import to_categorical

# Check whether a filename ends with any of the given extensions
def endwith(s, *endstring):
    return any(s.endswith(e) for e in endstring)

# Write all training image paths into a txt file
def write_imgpath():
    path = "/home/hq/desktop/cat_dog/c-d-data/train"  # training set directory
    with open('path.txt', 'w') as fp:
        for file in os.listdir(path):
            file_path = os.path.join(path, file)
            for sub_file in os.listdir(file_path):
                if endwith(sub_file, 'jpg'):
                    image_path = os.path.join(file_path, sub_file)
                    fp.write(image_path + "\n")
# Split the paths into a training set and a test set
def write_test_imgpath():
    write_imgpath()
    fp1 = open('path_train.txt', 'w')
    fp2 = open('path_test.txt', 'w')
    count = 0
    with open('path.txt') as f:
        for line in f:
            count += 1
            line = line.replace('\n', '')
            if count <= 2500 or count >= 22501:  # first and last 2500 paths form the test set
                fp2.write(line + "\n")
            else:
                fp1.write(line + "\n")
    fp1.close()
    fp2.close()
write_test_imgpath()

# Takes the batch_size and the path file to read
def generator_data(batch_size, path):
    list_x = []
    list_y = []
    count = 0
    while True:  # loop forever: Keras pulls batches indefinitely across epochs
        with open(path) as f:
            for line in f:
                line = line.replace('\n', '')
                x, y = process_line(line)
                list_x.append(x)
                list_y.append(y)
                count += 1
                if count >= batch_size:
                    yield (np.array(list_x), np.array(list_y))
                    count = 0
                    list_x = []
                    list_y = []
# Read one image and build its one-hot label from the file name
def process_line(line):
    img = cv2.imread(line)
    img = cv2.resize(img, (255, 255), interpolation=cv2.INTER_CUBIC)
    name = line.split('.')[0].split('/')[-1]  # e.g. .../dog/dog.123.jpg -> 'dog'
    label = 1 if name == 'dog' else 0
    label_onehot = to_categorical(label, 2)  # one-hot encoding
    return img, label_onehot

def VGG_16():
    model = Sequential()
    model.add(Conv2D(64, (3, 3),
           activation='relu',
           padding='same',
           name='block1_conv1',
           input_shape=(255, 255, 3)))
    model.add(Conv2D(64, (3, 3),
           activation='relu',
           padding='same',
           name='block1_conv2'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool'))

    model.add(Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block2_conv1'))
    model.add(Conv2D(128, (3, 3),
           activation='relu',
           padding='same',
           name='block2_conv2'))
    print(model.output.shape)
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool'))
    model.add(Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv1'))
    model.add(Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv2'))
    model.add(Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv3'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool'))
    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv1'))
    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv2'))
    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv3'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool'))
    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv1'))
    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv2'))

    model.add(Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv3'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool'))

    model.add(Flatten(name='flatten'))
    model.add(Dense(4096, activation='relu',name='fc1'))

    model.add(Dense(4096, activation='relu', name='fc2'))

    model.add(Dense(2, activation='sigmoid', name='predictions'))
    print(model.summary())
    return model


model = VGG_16()
sgd = SGD(lr=0.000001,decay=1e-6,momentum=0.9,nesterov=True)
model.compile(optimizer=sgd,  loss='squared_hinge',  metrics=['accuracy'])
model.fit_generator(generator_data(32, "/home/hq/desktop/cat_dog/path_train.txt"),
                    steps_per_epoch=20000 // 32,  # batches per epoch: 20000 training images / batch_size
                    epochs=10)
loss, accuracy = model.evaluate_generator(generator_data(32, "/home/hq/desktop/cat_dog/path_test.txt"),
                                          steps=5000 // 32)  # 5000 test images / batch_size
print("loss is:", loss)
print("accuracy is:", accuracy)
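Note that `steps_per_epoch` and `steps` count batches, not samples, so they should be derived from the dataset size and batch size. A small hypothetical helper (`steps_for` is not part of the script above) makes the arithmetic explicit:

```python
import math

def steps_for(num_samples, batch_size):
    """Batches per epoch: fit_generator advances one batch per step."""
    return math.ceil(num_samples / batch_size)

# 20000 training images and 5000 test images with batch_size 32:
print(steps_for(20000, 32))  # 625
print(steps_for(5000, 32))   # 157
```

Setting the step count to the sample count instead would make each "epoch" pass over the data batch_size times.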



 
