使用sklearn的GridSearchCV对keras进行多个超参数交叉验证

神经网络是玄学,很大一部分时间都是在花费参数的搭配选取上。如果在计算资源充足的情况下,使用网格搜索选取最优参数,就可以节省大量时间。下面给出示例:

Scikit-Learn里有一个API 为model.selection.GridSearchCV,可以将keras搭建的模型传入,作为sklearn工作流程一部分。

以下为keras的两个包装器,分别适用于分类和回归

keras.wrappers.scikit_learn.KerasClassifier(build_fn = None,** sk_params),它实现了Scikit-Learn分类器接口。

keras.wrappers.scikit_learn.KerasRegressor(build_fn = None,** sk_params),它实现了Scikit-Learn回归量接口。

适用手写数字mnist来测试

import tensorflow as tf
import keras
from keras import layers
from keras import models
from keras import utils
from keras.layers import Dense
from keras.models import Sequential
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import Activation
from keras.optimizers import RMSprop
from keras import datasets
from keras.callbacks import History
from keras import losses
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import numpy as np
#导入数据
mnist = keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train.shape, y_train.shape
#画出部分数据
plt.figure(figsize=(10,10))
for i in range(10):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)  #binary表示,画图的时,将大于0的统统用1表示
    plt.xlabel(y_train[i])

在这里插入图片描述

# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将二维的28*28的数据拉平
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)

num_classes = 10

#使用one-hot编码
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

#确定shape后的x_train,y_train的维度
print(x_train.shape[1])
print(y_train.shape)
input_dim = x_train.shape[1]
num_classes = 10

def my_model( init='glorot_uniform'):
    model = Sequential()
    model.add(Dense(64, input_dim=input_dim, kernel_initializer=init, activation='relu'))
    model.add(Dropout(0.1))
    model.add(Dense(64, kernel_initializer=init, activation=tf.nn.relu))
    model.add(Dense(num_classes, kernel_initializer=init, activation=tf.nn.softmax))

    # 编译模型
    model.compile(loss='categorical_crossentropy', 
                 optimizer=RMSprop(),
                  metrics=['accuracy'])
	return model

导入必要的库

import numpy
from sklearn.model_selection import GridSearchCV  
from keras.wrappers.scikit_learn import KerasClassifier  
%%time    #魔法命令,将会给出当前cell的代码运行一次所花费的时间。
seed = 2
numpy.random.seed(seed)

#使用包装器创建model
model_init_batch_epoch_CV = KerasClassifier(build_fn=my_model, verbose=1)

#指定 初始化方式init_mode,和batches和epochs为超参数
init_mode = ['glorot_uniform', 'uniform'] 
batches = [128, 512]
epochs = [10, 20]

param_grid = dict(epochs=epochs, batch_size=batches, init=init_mode)

#注意,前面我们知道x_train的样本数为60000,传入x_train做训练时,因为cv=3,
#因此会划分为3个20000,即对应每个参数组合,只会训练40000个样本,另外20000样本做为测试,
#作为准确率参考。最后取3次训练的测试准确率均值为该参数组合最终准确率。
grid = GridSearchCV(estimator=model_init_batch_epoch_CV, 
                    param_grid=param_grid,
                    cv=3)
grid_result = grid.fit(x_train, y_train)

训练过程不展示了,来看看最后的结果

# print results
print(f'Best Accuracy for {grid_result.best_score_:.4} using {grid_result.best_params_}')
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print(f'mean={mean:.4}, std={stdev:.4} using {param}')

glorot_uniform为keras.layers中默认的参数初始化方式,看来效果还不错。
在这里插入图片描述

!!!Attention 计算机资源不够情况下,不要轻易做这样的网格搜索。

一般工业上,建议还是使用RandomizedSearchCV进行参数搜索。
RandomizedSearchCV的使用方法其实是和GridSearchCV一致的,但它以随机在参数空间中采样的方式代替了GridSearchCV对于参数的网格搜索,在对于有连续变量的参数时,RandomizedSearchCV会将其当作一个分布进行采样,这是网格搜索做不到的,它的搜索能力取决于设定的n_iter参数。

回头有机会再继续介绍!
参考链接:
https://blog.csdn.net/juezhanangle/article/details/80051256
https://mp.weixin.qq.com/s/iL23G0v_-HuyoQyl7PjKbg

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值