Keras Tuner: A Beginner's Tutorial for Automated Hyperparameter Tuning

This is mainly a translation of the Keras Tuner documentation: https://keras-team.github.io/keras-tuner/documentation/tuners/
GitHub: https://github.com/keras-team/keras-tuner

Keras Tuner is a distributed hyperparameter optimization framework that searches a defined hyperparameter space for the best configuration. It ships with Bayesian optimization, Hyperband, and random search algorithms.

The original docs only give illustrative snippets that do not run as-is, so I made several changes:

  1. The original has no data input; I added it so the code runs correctly.
  2. The original models lacked an input layer and had data-alignment issues; I fixed those.
  3. Some argument names differ in newer versions; I updated them, mainly objective='val_Precision' for Hyperband and the classes argument for HyperXception.

My GPU could not keep up: the HyperXception and HyperResNet models ran all night with no end in sight, so I stopped them. Even after cutting the CIFAR-10 training set from 50,000 to 10,000 samples they still seem to take a long time. Since this is just a Keras Tuner walkthrough I did not finish those runs; if you are interested, give them a try and share your results.

Keras Tuner

Keras Tuner is a distributed hyperparameter optimization framework for tuning Keras models, especially tf.keras on TensorFlow 2.0. It makes it easy to define a search space and uses built-in algorithms (Bayesian optimization, Hyperband, and random search) to find good hyperparameter values. Full documentation and tutorials are on the Keras Tuner website.

Installation

Requirements:

  • Python 3.6
  • TensorFlow 2.0

Installation command:

pip install -U keras-tuner

Install from source:

git clone https://github.com/keras-team/keras-tuner.git
cd keras-tuner
pip install .
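
To check that the installation worked, you can import the package and print its version (my addition; this assumes the kerastuner package exposes __version__, which releases of that era do):

import kerastuner
print(kerastuner.__version__)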

Basic usage

This example shows how to use random search to find good hyperparameters for a single-layer dense network. First, define a model-building function. It takes an hp argument from which hyperparameters can be sampled, e.g. hp.Int('units', min_value=32, max_value=512, step=32) (an integer within a given range). The function returns a compiled model.

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = mnist.load_data()
x_train = np.expand_dims(x_train.astype('float32') / 255, -1)
x_val = np.expand_dims(x_val.astype('float32') / 255, -1)
# One-hot encoding is only needed with loss='categorical_crossentropy';
# the model below uses sparse integer labels, so these stay commented out.
# y_train = to_categorical(y_train, 10)
# y_val = to_categorical(y_val, 10)

Define the model-building function

from kerastuner.tuners import RandomSearch

# Build the model. The hp argument defines the ranges of the parameters
# to tune, which together form the search space.
def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    model.add(layers.Dense(units=hp.Int('units',
                                        min_value=32,
                                        max_value=512,
                                        step=32),
                           activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Choice('learning_rate',
                      values=[1e-2, 1e-3, 1e-4])),
#         loss='categorical_crossentropy',  # use with one-hot labels
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model
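
As a quick sanity check (not in the original docs), you can call build_model yourself with a fresh HyperParameters object; every hp.Int and hp.Choice then returns its default value, so you get one concrete compiled model:

from kerastuner import HyperParameters

# Sample the defaults manually: hp.Int defaults to its min_value (32),
# hp.Choice to its first value (1e-2), so this builds one concrete model.
hp = HyperParameters()
model = build_model(hp)
model.summary()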

Next, instantiate a tuner. You specify the model-building function, the name of the objective to optimize (whether it should be maximized or minimized is inferred automatically for built-in metrics), the total number of trials to run (max_trials), and the number of models built and trained per trial (executions_per_trial). Available tuners include RandomSearch and Hyperband.
Note: the point of running each trial multiple times is to reduce the variance of the results and thus measure model performance more accurately. If you want results faster, set executions_per_trial=1 to train each configuration only once.

# Use random search
tuner = RandomSearch(
    build_model,
    objective='val_accuracy',  # objective: validation accuracy (maximized)
    max_trials=5,              # try 5 hyperparameter configurations in total
    executions_per_trial=3,    # train the model 3 times per trial
    directory='my_dir',
    project_name='helloworld')

Print the search space

You can print a summary of the search space with the following code:

tuner.search_space_summary()
Search space summary
|-Default search space size: 2
units (Int)
|-conditions: []
|-default: None
|-max_value: 512
|-min_value: 32
|-sampling: None
|-step: 32
learning_rate (Choice)
|-conditions: []
|-default: 0.01
|-ordered: True
|-values: [0.01, 0.001, 0.0001]

Start the search

Then start searching for the best configuration. The call to search has the same signature as model.fit().

tuner.search(x_train, y_train,
             epochs=5,
             validation_data=(x_val, y_val))

Output (heavily truncated; the per-batch progress logs are far too long to keep):

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [==============================] - 8s 141us/sample - loss: 0.2749 - accuracy: 0.9234 - val_loss: 0.1597 - val_accuracy: 0.9536
Epoch 2/5
60000/60000 [==============================] - 8s 131us/sample - loss: 0.1322 - accuracy: 0.9613 - val_loss: 0.1164 - val_accuracy: 0.9655
Epoch 3/5
...

Here is what happens during the search: the tuner repeatedly builds models by calling the model-building function with configurations drawn from the hyperparameter space (the search space) tracked by hp. The tuner progressively explores the space and records the evaluation result of each configuration.

Retrieve the best models

When the search is finished, you can retrieve the best models:

# Return the two best models
models = tuner.get_best_models(num_models=2)
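
You can also retrieve the best hyperparameter values themselves and rebuild a fresh model from them. A short sketch (my addition, using keras-tuner's get_best_hyperparameters):

# Best hyperparameter configuration found during the search
best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
print(best_hp.values)  # e.g. {'units': 256, 'learning_rate': 0.001}

# Rebuild and retrain a fresh model with those values
model = build_model(best_hp)
model.fit(x_train, y_train, epochs=5, validation_data=(x_val, y_val))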

Print the results

You can also print a summary of the results:

tuner.results_summary()
Results summary
|-Results in my_dir/helloworld
|-Showing 10 best trials
|-Objective(name='val_accuracy', direction='max')
Trial summary
|-Trial ID: 71fc41aef4fc34c049c2f3b22a74252f
|-Score: 0.9792666435241699
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 256
Trial summary
|-Trial ID: b184c03f4c418071edd3b5afa390f952
|-Score: 0.9789333343505859
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 224
Trial summary
|-Trial ID: 4111f77d4d668a6a593030c902074bec
|-Score: 0.9765666127204895
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 128
Trial summary
|-Trial ID: 37e520f0bf10bbb4f3a32806700dce93
|-Score: 0.9545333385467529
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.0001
|-units: 160
Trial summary
|-Trial ID: 9ec337e0d6d8ff7e7c53fdf9931557a5
|-Score: 0.9280333518981934
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.0001
|-units: 32

Detailed logs and checkpoints are saved under directory/project_name, i.e. my_dir/helloworld in this example.
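
Because the results are persisted there, re-creating a tuner with the same directory and project_name picks up the previous trials instead of starting from scratch. A hedged sketch, assuming keras-tuner's default reload-from-directory behavior:

# Re-instantiating with the same directory/project_name reloads the
# completed trials from my_dir/helloworld.
tuner = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    directory='my_dir',
    project_name='helloworld')
tuner.results_summary()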

Conditional hyperparameters

The search space may contain conditional hyperparameters. Below, a for loop creates a tunable number of layers, each with its own tunable units parameter. This generalizes to any level of parameter interdependence, including recursion.
Note that all parameter names must be unique (here, the parameter inside the i-th loop iteration is named 'units_' + str(i)).

# Build the model
def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    # Tunable number of layers
    for i in range(hp.Int('num_layers', 2, 20)):
        # Unique name for the parameter tuned in each iteration
        model.add(layers.Dense(units=hp.Int('units_' + str(i),
                                            min_value=32,
                                            max_value=512,
                                            step=32),
                               activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Choice('learning_rate', [1e-2, 1e-3, 1e-4])),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model
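
Conditionality is not limited to loops. As a sketch of the same idea (my example, not from the original docs), a Boolean hyperparameter can gate whether a Dropout layer exists at all, and its rate is only sampled when it does:

def build_model_with_dropout(hp):
    model = keras.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    model.add(layers.Dense(hp.Int('units', 32, 512, step=32),
                           activation='relu'))
    # 'rate' is only registered when 'use_dropout' is sampled as True
    if hp.Boolean('use_dropout'):
        model.add(layers.Dropout(hp.Float('rate', 0.1, 0.5, step=0.1)))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model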

Subclassing HyperModel

Instead of a model-building function, you can subclass HyperModel. This makes hypermodels easy to share and reuse. A HyperModel subclass only needs to implement a build(self, hp) method.

from kerastuner import HyperModel

class MyHyperModel(HyperModel):

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def build(self, hp):
        model = keras.Sequential()
        model.add(layers.Input(shape=(28, 28, 1)))
        model.add(layers.Flatten())
        model.add(layers.Dense(units=hp.Int('units',
                                            min_value=32,
                                            max_value=512,
                                            step=32),
                               activation='relu'))
        model.add(layers.Dense(self.num_classes, activation='softmax'))
        model.compile(
            optimizer=keras.optimizers.Adam(
                hp.Choice('learning_rate',
                          values=[1e-2, 1e-3, 1e-4])),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])
        return model


hypermodel = MyHyperModel(num_classes=10)

tuner = RandomSearch(
    hypermodel,
    objective='val_accuracy',
    max_trials=10,
    directory='my_dir',
    project_name='helloworld1')

tuner.search(x_train, y_train,
             epochs=5,
             validation_data=(x_val, y_val))

Predefined tunable applications

Keras Tuner includes predefined tunable applications: HyperResNet and HyperXception. These are ready-to-use hypermodels for computer vision. They come precompiled with loss="categorical_crossentropy" and metrics=["accuracy"].

# Load the data
from tensorflow.keras.datasets import cifar10
NUM_CLASSES = 10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# The full dataset tunes far too slowly; keep a subset
x_train = x_train[:10000]
x_test = x_test[:2000]
y_train = to_categorical(y_train, NUM_CLASSES)[:10000]
y_test = to_categorical(y_test, NUM_CLASSES)[:2000]

Use a predefined hypermodel

from kerastuner.applications import HyperResNet
from kerastuner.tuners import Hyperband

hypermodel = HyperResNet(input_shape=(32, 32, 3), classes=10)

tuner = Hyperband(
    hypermodel,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_resnet')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))

Tuning a subset of the parameters

You can easily restrict the search space to tune only some of the parameters. If you have an existing hypermodel and only want to tune a few of its parameters (such as the learning rate), pass a hyperparameters argument to the tuner constructor together with tune_new_entries=False, so that parameters not listed in hyperparameters are not tuned and keep their default values.

from kerastuner import HyperParameters
from kerastuner.applications import HyperXception
from kerastuner.tuners import Hyperband

hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

hp = HyperParameters()
# This will tune `learning_rate` over the listed choices
hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

tuner = Hyperband(
    hypermodel,
    hyperparameters=hp,
    # `tune_new_entries=False` prevents unlisted parameters from being tuned
    tune_new_entries=False,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))
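
To verify that only learning_rate is now tunable, print the restricted search space (my addition):

tuner.search_space_summary()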

Want to know which parameters are available to tune? Read the source code.

Setting parameter defaults

When you register a hyperparameter in a model-building function or in a HyperModel's build method, you can give it a default value:

hp.Int('units',
       min_value=32,
       max_value=512,
       step=32,
       default=128)

If you do not set one, a default is chosen automatically (e.g. for Int, the default is its min_value).
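
A small sketch of what this means in practice (my example): registering a hyperparameter on a standalone HyperParameters object returns its default value.

from kerastuner import HyperParameters

hp = HyperParameters()
units = hp.Int('units', min_value=32, max_value=512, step=32, default=128)
print(units)  # 128: the default, until a tuner overrides it during search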

Fixing some hyperparameter values

What about the opposite case: tuning all of the hypermodel's available parameters except one (such as the learning rate)?
Pass a hyperparameters argument containing one (or more) Fixed entries, and set tune_new_entries=True:

hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

hp = HyperParameters()
hp.Fixed('learning_rate', value=1e-4)

tuner = Hyperband(
    hypermodel,
    hyperparameters=hp,
    tune_new_entries=True,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception1')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))

Overriding compilation arguments

If you have a hypermodel whose optimizer, loss, or metrics you want to override, pass those arguments to the tuner constructor, as shown below:

hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

tuner = Hyperband(
    hypermodel,
    optimizer=keras.optimizers.Adam(1e-3),
    loss='mse',
    # the metric name must match the objective below ('val_' + metric name)
    metrics=[keras.metrics.Precision(name='Precision'),
             keras.metrics.Recall(name='recall')],
    objective='val_Precision',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception2')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))