Keras中的MultiStepLR

Keras 没有像 PyTorch 的 MultiStepLR 那样内置的"多步调整学习率"调度回调,所以博主自己写了一个:在 lr_decay_epochs 列出的每个 epoch 开始时(Keras 的 epoch 从 0 开始计数),把当前学习率乘以 lr_decay_factor。

1. 代码

from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import backend as K
import numpy as np
import argparse


parser = argparse.ArgumentParser()
# `type=list` would apply list() to the raw CLI string, turning "2,5,7" into
# ['2', ',', '5', ',', '7'] -- characters that never compare equal to an
# integer epoch, so the schedule would silently never fire. Parse one int per
# token instead, e.g. `--lr_decay_epochs 2 5 7`.
parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[2, 5, 7],
                    help="For MultiFactorScheduler step")
parser.add_argument('--lr_decay_factor', type=float, default=0.1,
                    help="Multiplier applied to the learning rate at each decay epoch")
# parse_known_args tolerates unrelated flags defined elsewhere in the program.
args, _ = parser.parse_known_args()


def get_lr_scheduler(args):
    """Create the multi-step learning-rate callback.

    Args:
        args: parsed namespace; ``MultiStepLR`` reads its
            ``lr_decay_epochs`` and ``lr_decay_factor`` attributes.

    Returns:
        A ``MultiStepLR`` callback instance built with default verbosity.
    """
    return MultiStepLR(args=args)


class MultiStepLR(Callback):
    """Multi-step learning-rate scheduler callback.

    At the start of every epoch whose index appears in ``args.lr_decay_epochs``,
    the optimizer's *current* learning rate is multiplied by
    ``args.lr_decay_factor`` (once per occurrence of that epoch in the list).
    Every other epoch leaves the learning rate unchanged.

    Because the decay is relative to the optimizer's current value, the
    schedule assumes training starts from the initial learning rate.
    NOTE(review): when resuming with ``initial_epoch > 0``, milestones that
    were already passed are not re-applied -- confirm this is acceptable.

    Arguments:
        args: parsed namespace providing ``lr_decay_epochs`` (epoch indices,
            compared against the ``epoch`` value passed to ``on_epoch_begin``)
            and ``lr_decay_factor`` (multiplier applied at each milestone).
        verbose: int. 0: quiet, 1: print a message whenever the learning
            rate is actually reduced.
    """

    def __init__(self, args, verbose=0):
        super(MultiStepLR, self).__init__()
        self.args = args
        self.steps = args.lr_decay_epochs
        self.factor = args.lr_decay_factor
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        """Apply the schedule for ``epoch`` to the optimizer's learning rate.

        Raises:
            ValueError: if the optimizer has no ``lr`` attribute, or if
                ``schedule`` does not return a float.
        """
        optimizer = self.model.optimizer
        if not hasattr(optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Remember the rate before scheduling so the verbose message below is
        # only emitted when the rate actually changes.
        old_lr = K.get_value(optimizer.lr)
        lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(optimizer.lr, lr)
        # Always report the rate in effect for this epoch; trailing zeros of
        # the fixed 7-decimal representation are trimmed.
        print("learning rate: {:.7f}".format(K.get_value(optimizer.lr)).rstrip('0'))
        # Previously this fired every epoch, claiming a reduction even when
        # the rate was untouched.
        if self.verbose > 0 and lr != old_lr:
            print('\nEpoch %05d: MultiStepLR reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def schedule(self, epoch):
        """Return the learning rate to use for ``epoch``.

        Starts from the optimizer's current learning rate and multiplies it by
        ``self.factor`` once for every entry of ``self.steps`` equal to
        ``epoch``.
        """
        lr = K.get_value(self.model.optimizer.lr)
        for step in self.steps:
            if epoch == step:
                lr = lr * self.factor
        return lr

2. 调用方法:把 lr_scheduler 加入 callbacks 列表,再将该列表通过 callbacks 参数传给 fit_generator。

# Build the callback list containing the multi-step learning-rate scheduler.
lr_scheduler = get_lr_scheduler(args=args)
callbacks = [lr_scheduler]

...  # model and data-generator construction omitted
# NOTE(review): newer tf.keras versions deprecate fit_generator in favour of
# fit(), which accepts generators -- confirm for the TF version in use.
model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // args.batch_size,
    validation_data=test_generator,
    validation_steps=test_generator.samples // args.batch_size,
    workers=args.num_workers,
    callbacks=callbacks,  # your callbacks, including lr_scheduler
    epochs=args.epochs,
)

大家可以拿去用~

发布了127 篇原创文章 · 获赞 979 · 访问量 142万+

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览