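Training a model can take a long time. If training is interrupted unexpectedly, or you want to adjust hyperparameters, you need to continue from the previous run rather than start from scratch. This post follows the idea used in tf-slim to implement checkpoint-resume training in Keras: when fine-tuning, if the checkpoint path holds no saved parameters, the model is initialized from Google's pretrained model; if it does hold saved parameters (the model has been trained before), they are restored and training continues from where it left off.
The next section first summarizes the ways to initialize model parameters (skip it if you already know the basics).
I. Ways to initialize model parameters
1. Random initialization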
base_model = VGG16(weights=None, include_top=False, input_shape=(224, 224, 3))
When weights is set to None (the Python value, not the string 'None'), the model parameters are initialized randomly.
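2. Initialization from saved model parameters
The saved parameters can come from Google's pretrained model (available for download), or they can be parameters you saved yourself.
1) Initializing from parameters you saved yourself
Use model.load_weights(weights_path) to restore the parameters: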
import os
import sys
# Despite its name, checkpoint_dir is the path of the checkpoint file itself
checkpoint_dir = '/data/sfang/logo_classify/keras_model/checkpoint/best.hdf5'
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n' % checkpoint_dir)
    model.load_weights(checkpoint_dir)
else:
    sys.stdout.write('No checkpoint found\n')
2) Initializing from Google's pretrained model
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
Here weights='imagenet'. The VGG16 source below shows that this option also restores the parameters with model.load_weights(weights_path):
# load weights
if weights == 'imagenet':
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path)
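Notes:
1) Sometimes you only want to restore some of the layers. For example, when fine-tuning you may add a few layers after the VGG16 base_model and want to initialize only base_model from Google's pretrained model. In that case model.load_weights(weights_path) raises an error, because the number of layers in the model differs from the number of layers in the weight file. Loading by layer name solves this: model.load_weights(weights_path, by_name=True)
2) Besides model.load_weights(weights_path), you can also initialize the parameters layer by layer yourself: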
# Copy every layer except the last one from the pretrained model
for i in range(len(model_pretrained.layers) - 1):
    model_new.layers[i].set_weights(model_pretrained.layers[i].get_weights())
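To make the by-name idea in note 1) concrete, here is a minimal pure-Python sketch of the matching rule. It is not Keras's implementation: load_by_name and the two dictionaries are hypothetical stand-ins for a model and a weight file, and the sketch skips the shape checks a real loader would perform.

```python
def load_by_name(target_layers, saved_weights):
    """Copy weights into the layers whose names appear in saved_weights.

    target_layers: dict mapping layer name -> current weights
    saved_weights: dict mapping layer name -> saved weights
    Returns the names of the layers that were loaded, in model order.
    """
    loaded = []
    for name in target_layers:
        if name in saved_weights:
            target_layers[name] = saved_weights[name]
            loaded.append(name)
    return loaded


# The base layers share names with the pretrained file; the new head does not.
model_layers = {"block1_conv1": None, "block1_conv2": None, "fc6": None}
pretrained = {"block1_conv1": [1.0], "block1_conv2": [2.0]}

loaded = load_by_name(model_layers, pretrained)
print(loaded)  # ['block1_conv1', 'block1_conv2']
```

The layer named fc6 has no counterpart in the saved file, so it keeps whatever initialization it already had. That is the behavior you want for newly added layers when fine-tuning.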
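II. Checkpoint-resume training
Take VGG16 fine-tuning as an example. When training starts, if a file of saved parameters exists, the model is initialized from that file and continues the previous run; otherwise it is initialized from Google's pretrained model.
Checkpoint-resume code: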
weights = 'imagenet'
include_top = False
# load weights
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n' % checkpoint_dir)
    model.load_weights(checkpoint_dir)
elif weights == 'imagenet':
    sys.stdout.write('INFO:Load weights from imagenet\n')
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path, by_name=True)
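The decision in the snippet above can also be pulled out into a small helper, which keeps the training script shorter. This is only a sketch under assumptions: resolve_initial_weights is a hypothetical name, and it assumes the pretrained file has already been fetched (for example with get_file) so that both arguments are local paths.

```python
import os


def resolve_initial_weights(checkpoint_path, pretrained_path):
    """Pick the weight file to start from and how it should be loaded.

    Returns (path, by_name). by_name is False for our own checkpoint,
    which matches the full model, and True for the pretrained file,
    which only covers the base layers.
    """
    if os.path.exists(checkpoint_path):
        return checkpoint_path, False
    return pretrained_path, True


# Intended use with a Keras model (not executed here):
#   path, by_name = resolve_initial_weights(checkpoint_dir, weights_path)
#   model.load_weights(path, by_name=by_name)
```

Returning the by_name flag together with the path keeps the two cases from drifting apart: a checkpoint is always loaded in full, and the pretrained file is always matched by layer name.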
Data layout:
Directory structure:
The folders 0, 1, 2, 3 and 4 each hold the images of one class. This example has five classes in total.
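Keras's flow_from_directory treats each subdirectory of the given folder as one class, so it can be worth checking the folder before a long training run. Below is a minimal sketch using only the standard library. count_images_per_class and the list of extensions are my own assumptions for illustration, not part of Keras.

```python
import os

# Assumed set of image extensions; adjust to match your data.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def count_images_per_class(root_dir):
    """Return {class_name: image_count} for each subdirectory of root_dir."""
    counts = {}
    for entry in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # plain files directly under root_dir are not classes
        counts[entry] = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts
```

For the layout described here, calling count_images_per_class(train_dir) should return a dictionary with the five keys '0' to '4'. An empty or missing class folder shows up immediately instead of partway through training.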
Full project code:
#coding=utf-8
import keras
import os
import glob
import sys
import argparse
import tensorflow as tf
from matplotlib import pyplot as plt
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Model
from keras.utils import get_file
from keras.layers import Dense,GlobalAveragePooling2D,Flatten,Dropout,Conv2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD,RMSprop
from keras.callbacks import TensorBoard
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report
from keras.backend.tensorflow_backend import set_session
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.4
# set_session(tf.Session(config=config))
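IM_WIDTH, IM_HEIGHT = 299, 299  # image size expected by InceptionV3 (not used below; the generators resize to 224x224)
FC_SIZE = 1024  # number of units in the fully connected layer
NB_IV3_LAYERS_TO_FREEZE = 172  # number of layers to freeze
# Paths to the training and test sets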
train_dir = '/data/sfang/logo_classify/data/image_preprocessed/images/train/'
test_dir = '/data/sfang/logo_classify/data/image_preprocessed/images/test/'
val_dir = '/data/sfang/logo_classify/data/image_classify/val/'
checkpoint_dir = '/data/sfang/logo_classify/keras_model/checkpoint/best.hdf5'
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
# Define the data generators with the ImageDataGenerator class
train_datagen = ImageDataGenerator(
    # preprocessing_function=preprocess_input,  # image preprocessing function
    # rotation_range=30,  # rotation angle range
    # width_shift_range=0.2,
    # height_shift_range=.2,
    # shear_range=.2,
    # zoom_range=.2,
    # horizontal_flip=True,
    rescale=1./255
)
test_datagen = ImageDataGenerator(
    # preprocessing_function=preprocess_input,  # image preprocessing function
    # rotation_range=30,  # rotation angle range
    # width_shift_range=0.2,
    # height_shift_range=.2,
    # shear_range=.2,
    # zoom_range=.2,
    # horizontal_flip=True,
    rescale=1./255
)
validation_datagen = ImageDataGenerator(
    # preprocessing_function=preprocess_input,  # image preprocessing function
    # rotation_range=30,  # rotation angle range
    # width_shift_range=0.2,
    # height_shift_range=.2,
    # shear_range=.2,
    # zoom_range=.2,
    # horizontal_flip=True,
    rescale=1./255
)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=64,
    class_mode='categorical',
)
validation_generator = validation_datagen.flow_from_directory(
    val_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical'
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical'
)
def classify_model(base_model, nums_classes):
    # Add a VGG-style classification head on top of the convolutional base
    x = base_model.output
    x = Conv2D(4096, (7, 7), activation='relu', padding='valid', name='fc5')(x)
    x = Flatten()(x)
    x = Dense(4096, activation='relu', name='fc6')(x)
    x = Dropout(rate=.5, name='dropout6')(x)
    x = Dense(4096, activation='relu', name='fc7')(x)
    x = Dropout(rate=.5, name='dropout7')(x)
    predictions = Dense(nums_classes, activation='softmax', name='fc8')(x)
    # Keras 2 expects the keyword arguments inputs/outputs
    model = Model(inputs=base_model.input, outputs=predictions, name='my_vgg16')
    return model

def define_trainable_layers(model, base_model):
    # Uncomment the next two lines to freeze the base layers and train only the new head
    # for layer in base_model.layers:
    #     layer.trainable = False
    opt = RMSprop(lr=1e-4)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
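# Build the model
# bottleneck = vgg16
# To avoid initializing twice, the weights are not loaded here
base_model = VGG16(weights=None, include_top=False, input_shape=(224, 224, 3))
# Build the full model by adding the classification head to the base
model = classify_model(base_model, 5)
weights = 'imagenet'
include_top = False
# load weights
# If saved parameters exist, continue the previous run; otherwise start from the ImageNet pretrained weights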
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n' % checkpoint_dir)
    model.load_weights(checkpoint_dir)
elif weights == 'imagenet':
    sys.stdout.write('INFO:Load weights from imagenet\n')
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path, by_name=True)
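# Print the model summary, then set the trainable layers and compile
model.summary()
define_trainable_layers(model, base_model)
# Monitor a metric; whenever it improves after an epoch (loss goes down or accuracy goes up), save the model
# val_acc should be maximized, so mode is 'max' (newer Keras versions name this metric 'val_accuracy')
checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_acc',
                             mode='max', save_best_only=True, verbose=1)
callbacks = [checkpoint, TensorBoard(log_dir='./log')]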
history = model.fit_generator(
    train_generator,
    epochs=150,
    shuffle=True,
    callbacks=callbacks,
    steps_per_epoch=1028,
    validation_data=test_generator,
    validation_steps=5
)
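The checkpoint callback in the script saves the model only when the monitored metric improves. What counts as "improves" depends on the mode, and the pure-Python sketch below illustrates that rule. It is a simplified model of the idea, not Keras's ModelCheckpoint code, and epochs_that_save is a hypothetical helper name.

```python
def epochs_that_save(metric_values, mode):
    """Return the epoch indices at which a save-best-only checkpoint is written."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    saved = []
    for epoch, value in enumerate(metric_values):
        improved = (
            best is None
            or (mode == "max" and value > best)
            or (mode == "min" and value < best)
        )
        if improved:
            best = value
            saved.append(epoch)
    return saved


val_acc = [0.60, 0.72, 0.70, 0.81]
print(epochs_that_save(val_acc, "max"))  # [0, 1, 3]
print(epochs_that_save(val_acc, "min"))  # [0]
```

With an accuracy-like metric, 'max' keeps the checkpoint from the best epoch, while 'min' would keep only the first epoch in this run. For a loss-like metric such as val_loss the situation is reversed.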