ReduceLROnPlateau is a PyTorch scheduler (torch.optim.lr_scheduler.ReduceLROnPlateau) that dynamically lowers the learning rate when a monitored metric stops improving. If your project uses an optimizer created by torchtuples, however, you cannot use it directly, because the PyTorch scheduler expects an optimizer created by torch:
from torch.optim.lr_scheduler import ReduceLROnPlateau

lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-6)
# the optimizer argument here does not accept a torchtuples optimizer instance
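For contrast, here is a minimal sketch of how the scheduler is normally driven in a plain PyTorch training loop; the model and the constant validation loss are stand-ins for illustration only:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

net = torch.nn.Linear(10, 1)                    # stand-in model
optimizer = torch.optim.Adam(net.parameters())  # an optimizer created by torch
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                              patience=10, min_lr=1e-6)

for epoch in range(30):
    # ... training and validation steps omitted ...
    val_loss = 1.0  # stand-in for the real validation loss
    scheduler.step(val_loss)  # step with the monitored metric each epoch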
torchtuples, however, offers a way around this: subclass the Callback class it provides and implement the behavior yourself.
The following callback, built by subclassing tt.cb.Callback, reproduces ReduceLROnPlateau inside torchtuples:
import torchtuples as tt

class ReduceLROnPlateauCallback(tt.cb.Callback):
    def __init__(self, optimizer, factor=0.1, patience=10, min_lr=1e-6):
        super().__init__()
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience  # epochs to wait for improvement before reducing the lr
        self.min_lr = min_lr
        self.best_loss = float('inf')  # best validation loss seen so far
        self.wait = 0  # epochs since the last improvement in loss

    def on_epoch_end(self):
        # Read the latest validation loss from the training log
        current_loss = self.model.log.to_pandas().val_loss.iloc[-1]
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
        # wait has reached patience: reduce the lr of every param group
        if self.wait >= self.patience:
            for param_group in self.optimizer.param_groups:
                # Calculate the new learning rate, clamped at min_lr
                new_lr = max(param_group['lr'] * self.factor, self.min_lr)
                if new_lr < param_group['lr']:
                    param_group['lr'] = new_lr
                    print(f"Reducing learning rate to {new_lr}")
            self.wait = 0  # restart the patience counter after a reduction
Usage example:
from pycox.models import CoxCC

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features,
                              batch_norm, dropout, output_bias=output_bias)
optimizer = tt.optim.Adam()
model = CoxCC(net, optimizer)
# pass the optimizer the model actually holds, so the lr changes take effect
callbacks = [ReduceLROnPlateauCallback(model.optimizer)]
model.fit(df_xtrain.values, y_train, batch_size, epochs, callbacks, val_data=val)
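After training you can check that the callback behaved as expected by inspecting the logged losses and the learning rate it left in place. A quick check, assuming the same names as above and that model.optimizer exposes param_groups as the callback requires:

# per-epoch losses, the same log the callback reads from
log_df = model.log.to_pandas()
print(log_df[['train_loss', 'val_loss']].tail())

# the learning rate after any reductions
print(model.optimizer.param_groups[0]['lr'])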