基于torchtuples实现torch的ReduceLROnPlateau方法

ReduceLROnPlateau是torch的一个动态调整learning rate的方法,但如果项目中使用的是由torchtuples创建的optimizer,是无法直接使用该方法的,因为来自torch的该方法需要由torch创建的optimizer:

 

lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-6) 
# 此处的参数optimizer不接受torchtuples的实例

但是torchtuples提供了一个解决方案:继承tt提供的callback类,自行编写所需功能

以下是通过继承callback类,可以在torchtuples里使用的ReduceLROnPlateau方法:

import torchtuples as tt


class ReduceLROnPlateauCallback(tt.cb.Callback):
    def __init__(self, optimizer, factor=0.1, patience=10, min_lr=1e-6):
        super().__init__()
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience  # Number of epochs to wait for improvement before reducing the learning rate
        self.min_lr = min_lr
        self.best_loss = float('inf')  # Initialize to infinity
        self.wait = 0  # Number of epochs since the last improvement in loss

    def on_epoch_end(self):
        # Get the current valid loss
        current_loss = self.model.log.to_pandas().val_loss.iloc[-1]

        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            # wait reaches patience
            if self.wait >= self.patience:
                for param_group in self.optimizer.param_groups:
                    # Calculate the new learning rate and ensure it's not lower than the minimum learning rate
                    new_lr = max(param_group['lr'] * self.factor, self.min_lr)
                    # update lr and reset the wait
                    if new_lr < param_group['lr']:
                        param_group['lr'] = new_lr
                        self.wait = 0
                        print(f"Reducing learning rate to {new_lr}")

使用示例:

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)

optimizer = tt.optim.Adam()

model = CoxCC(net, optimizer)

callbacks = [ReduceLROnPlateauCallback(optimizer)]

model.fit(df_xtrain.values, y_train, batch_size, epochs, callbacks, val_data=val)

  • 9
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

奶昧蓝

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值