Keras学习之六:训练辅助及优化工具

11 篇文章 5 订阅

1 Callbacks

Callbacks提供了一系列的类,用于在训练过程中被回调,从而实现对训练过程进行观察和干涉。除了库提供的一些类,用户也可以自定义类。下面列举比较有用的回调类。

类名作用构造函数
ModelCheckpoint用于在epoch间保存要模型ModelCheckpoint(filepath, monitor=’val_loss’, save_best_only=False, save_weights_only=False, mode=’auto’, period=1)
EarlyStopping当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。EarlyStopping(monitor=’val_loss’, patience=0, mode=’auto’)
TensorBoard生成tb需要的日志TensorBoard(log_dir=’./logs’, histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)
ReduceLROnPlateau当指标变化小时,减少学习率ReduceLROnPlateau(monitor=’val_loss’, factor=0.1, patience=10, mode=’auto’, epsilon=0.0001, cooldown=0, min_lr=0)

示例:

from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

checkpointer = ModelCheckpoint(filepath="/tmp/weights.h5", save_best_only=True)
tensbrd = TensorBoard(logdir='path/of/log')
model.fit(X_train, Y_train, batch_size=128, callbacks=[checkpointer,tensbrd])

PS:加入tensorboard回调类后,就可以使用tensorflow的tensorboard命令行来打开可视化web服务了。

2 Application

本模块提供了基于image-net预训练好的图像模型,方便我们进行迁移学习使用。初次使用时,模型权重数据会下载到~/.keras/models目录下。

图像模型说明构造函数
InceptionV3InceptionV3(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
ResNet50ResNet50(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
VGG19VGG19(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
VGG16VGG16(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
XceptionXception(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None, classes=1000)

参数说明

参数说明
include_top是否保留顶层的全连接网络, False为只要bottleneck
weights‘imagenet’代表加载预训练权重, None代表随机初始化
input_tensor可填入Keras tensor作为模型的图像输出tensor
input_shape长为3的tuple,指明输入图片的shape,图片的宽高必须大于197
pooling特征提取网络的池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化
classes图片分类的类别数,当include_top=True weight=None时可用

关于迁移学习,可以参考这篇文章:如何在极小数据集上实现图像分类。里面介绍了通过图像变换以及使用已有模型并fine-tune新分类器的过程。

3 模型可视化

utils包中提供了plot_model函数,用来将一个model以图像的形式展现出来。此功能依赖pydot-ng与graphviz。
pip install pydot-ng graphviz

from keras.utils import plot_model
model = keras.applications.InceptionV3()
plot_model(model, to_file='model.png')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值