训练样本数量超过内存时,采用 Python 生成器按批次加载数据。
基于 Keras + TensorFlow 实现 ResNet-34。
直接上代码:
#coding=utf-8
import os
import re

import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
                          Conv2D, Dense, Flatten, Input, MaxPooling2D,
                          ZeroPadding2D, add)
from keras.models import Model
from keras.optimizers import SGD
from keras.preprocessing.image import load_img, img_to_array
from sklearn.preprocessing import LabelEncoder
seed = 7
np.random.seed(seed)
def getFileList(dir, fileList):
    """Recursively collect every file path under *dir* into *fileList*.

    dir: a file or directory path; fileList: list accumulator, mutated
    in place and also returned for call-chaining convenience.
    """
    if os.path.isfile(dir):
        # Leaf file: record the path as-is.  (The original called
        # dir.decode('gbk'), which is Python-2-only -- a Python 3 str
        # has no .decode and that line raised AttributeError.)
        fileList.append(dir)
    elif os.path.isdir(dir):
        for entry in os.listdir(dir):
            getFileList(os.path.join(dir, entry), fileList)
    return fileList
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):
    """Conv2D -> BatchNorm -> ReLU building block.

    BUG FIX: the original fused ReLU into Conv2D and batch-normalised
    the *activated* output (conv -> relu -> bn).  The reference ResNet
    ordering is conv -> bn -> relu, applied here.
    """
    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides)(x)
    x = BatchNormalization(axis=3)(x)  # axis=3: channels-last layout
    x = Activation('relu')(x)
    return x
def Conv_Block(inpt, nb_filter, kernel_size, strides=(1, 1), with_conv_shortcut=False):
    """Residual block: two Conv2d_BN layers plus a shortcut connection.

    With with_conv_shortcut=True the identity path is replaced by a
    projection convolution (needed when this block changes the spatial
    size or channel count); otherwise the input is added back unchanged.
    """
    out = Conv2d_BN(inpt, nb_filter=nb_filter, kernel_size=kernel_size,
                    strides=strides, padding='same')
    out = Conv2d_BN(out, nb_filter=nb_filter, kernel_size=kernel_size,
                    padding='same')
    if not with_conv_shortcut:
        # Plain identity shortcut.
        return add([out, inpt])
    projected = Conv2d_BN(inpt, nb_filter=nb_filter, strides=strides,
                          kernel_size=kernel_size)
    return add([out, projected])
def Resnet():
    """Build and compile a ResNet-34 for 224x224x3 binary classification.

    Returns the compiled keras Model (sigmoid head, binary crossentropy,
    SGD with momentum).  Layer sequence is identical to writing the 16
    residual blocks out by hand; the stage loop only removes repetition.
    """
    inpt = Input(shape=(224, 224, 3))
    # Stem: pad, 7x7/2 conv, 3x3/2 max-pool  ->  (56, 56, 64)
    net = ZeroPadding2D((3, 3))(inpt)
    net = Conv2d_BN(net, nb_filter=64, kernel_size=(7, 7), strides=(2, 2),
                    padding='valid')
    net = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(net)
    # Residual stages as (filters, block count): 3/4/6/3 = ResNet-34.
    # Every stage after the first opens with a strided projection block
    # that halves the spatial size: 56 -> 28 -> 14 -> 7.
    for stage, (filters, block_count) in enumerate(
            [(64, 3), (128, 4), (256, 6), (512, 3)]):
        for block in range(block_count):
            if stage > 0 and block == 0:
                net = Conv_Block(net, nb_filter=filters, kernel_size=(3, 3),
                                 strides=(2, 2), with_conv_shortcut=True)
            else:
                net = Conv_Block(net, nb_filter=filters, kernel_size=(3, 3))
    # Head: global 7x7 average pool, flatten, single sigmoid unit.
    net = AveragePooling2D(pool_size=(7, 7))(net)
    net = Flatten()(net)
    net = Dense(1, activation='sigmoid')(net)
    model = Model(inpt, net)
    model.compile(loss='binary_crossentropy',
                  optimizer=SGD(lr=0.001, momentum=0.9),
                  metrics=['accuracy'])
    model.summary()
    return model
def generateData(batch_size,
                 train_dir=r'/media/wmy/document/BigData/kaggle/DogsvsCats/train'):
    """Infinite generator of (images, labels) batches for fit_generator.

    Images are resized to 224x224; labels ('cat'/'dog') are read from
    the file name and integer-encoded with LabelEncoder.  The file list
    is reshuffled at the start of every pass; a trailing partial batch
    is dropped (the original behaved the same way).

    Generalized: the training directory is now a parameter (defaulting
    to the original hard-coded path), and the label is parsed from the
    basename ('cat.123.jpg' -> 'cat') instead of regexing on '/train/'
    in the full path, which broke for any other directory name.
    """
    encoder = LabelEncoder()
    encoder.fit(['cat', 'dog'])
    train_img = getFileList(train_dir, [])
    while True:
        np.random.shuffle(train_img)
        X, y = [], []
        for path in train_img:
            img = img_to_array(load_img(path, target_size=(224, 224)))
            X.append(img)
            y.append(os.path.basename(path).split('.')[0])
            if len(X) == batch_size:
                yield (np.array(X), np.array(encoder.transform(y)))
                X, y = [], []
def generateTestData(batch_size,
                     test_dir=r'/media/wmy/document/BigData/kaggle/DogsvsCats/test'):
    """Infinite generator of test-image batches for predict_generator.

    NOTE: getFileList yields os.listdir order, which is NOT numeric file
    order -- pred() re-sorts the predictions by image id afterwards.
    A trailing partial batch is dropped, as in the original.

    Generalized: the test directory is now a parameter (defaulting to
    the original hard-coded path).
    """
    test_img = getFileList(test_dir, [])
    while True:
        X = []
        for path in test_img:
            X.append(img_to_array(load_img(path, target_size=(224, 224))))
            if len(X) == batch_size:
                yield (np.array(X))
                X = []
def train_model(batch_size=20, epochs=50):
    """Train the ResNet, checkpointing the best weights to best_params.h5.

    BUG FIX: steps_per_epoch was 25000 -- the *sample* count, not the
    batch count -- so one "epoch" iterated the 25000-image training set
    batch_size times over.  fit_generator wants the number of batches
    per pass: sample_count // batch_size.
    """
    sample_count = 25000  # kaggle Dogs-vs-Cats training set size
    callbacks = [ModelCheckpoint(filepath='best_params.h5', monitor='acc',
                                 save_best_only=True, mode='max')]
    model = Resnet()
    model.fit_generator(generator=generateData(batch_size),
                        steps_per_epoch=sample_count // batch_size,
                        epochs=epochs, verbose=1,
                        callbacks=callbacks, max_q_size=4)
def pred(batch_size=20):
    """Predict on the test set and write a kaggle submission my_pred.csv.

    BUG FIX: steps was 12500 -- the *image* count -- but the infinite
    generator yields batches, so predict_generator produced
    12500 * batch_size rows and np.c_ could not pair them with the
    12500 ids.  steps must be the number of batches.
    """
    test_img = getFileList(r'/media/wmy/document/BigData/kaggle/DogsvsCats/test', [])
    # Image id from the path '<...>/test/<n>.jpg', in the SAME listdir
    # order the generator reads the files.
    numbers = [int(re.findall(r'.+?/test/([0-9]+)', path)[0])
               for path in test_img]
    model = Resnet()
    model.load_weights('best_params.h5')
    steps = len(test_img) // batch_size
    ypred = model.predict_generator(generator=generateTestData(batch_size),
                                    steps=steps, max_q_size=10, verbose=2)
    # Pair (id, probability); truncate ids in case a partial batch was
    # dropped, then sort rows by id so the csv runs 1..N.
    rows = np.c_[numbers[:steps * batch_size], ypred]
    rows = sorted(list(rows), key=lambda r: r[0])
    np.savetxt("my_pred.csv", rows, delimiter=',', header='id,label',
               comments='', fmt=['%d', '%f'])
if __name__ == '__main__':
    # Train first (writes best_params.h5), then predict with those weights.
    train_model()
    pred()