Keras中的MultiStepLR

本文介绍了一种在Keras中实现多步调整学习率的自定义调度器,通过解析命令行参数设置学习率衰减的周期和比例,适用于训练过程中的动态学习率调整。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:

1.代码

from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import backend as K
import numpy as np
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('--lr_decay_epochs', type=list, default=[2, 5, 7], help="For MultiFactorScheduler step")
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
args, _ = parser.parse_known_args()


def get_lr_scheduler(args):
    lr_scheduler = MultiStepLR(args=args)
    return lr_scheduler


class MultiStepLR(Callback):
    """Learning rate scheduler.

    Arguments:
        args: parser_setting
        verbose: int. 0: quiet, 1: update messages.
    """

    def __init__(self, args, verbose=0):
        super(MultiStepLR, self).__init__()
        self.args = args
        self.steps = args.lr_decay_epochs
        self.factor = args.lr_decay_factor
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.lr, lr)
        print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
        if self.verbose > 0:
            print('\nEpoch %05d: MultiStepLR reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def schedule(self, epoch):
        lr = K.get_value(self.model.optimizer.lr)
        for i in range(len(self.steps)):
            if epoch == self.steps[i]:
                lr = lr * self.factor

        return lr

2.调用(callbacks里append这个lr_scheduler,fit_generator里callbacks传入这个变量)

callbacks = []
lr_scheduler = get_lr_scheduler(args=args)
callbacks.append(lr_scheduler)

...
model.fit_generator(train_generator,
                            steps_per_epoch=train_generator.samples // args.batch_size,
                            validation_data=test_generator,
                            validation_steps=test_generator.samples // args.batch_size,
                            workers=args.num_workers,
                            callbacks=callbacks,  # 你的callbacks, 包含了lr_scheduler
                            epochs=args.epochs,
                            )

大家可以拿去用~

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

老兵安帕赫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值