Pytorch-Lightning--Tuner

Pytorch-Lightning–Tuner

lr_find()

参数详解

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
datamoduleLightningDataModule实例None
min_lr学习率最小值1e-08
max_lr学习率最大值1
num_training测试学习率的训练轮数100
mode学习率寻找策略,分为指数(默认)和线性(linear)exponential
early_stop_threshold当任意一点的loss>=early_stop_threshold*best_loss时停止搜索,设置为None禁用该项4.0
update_attr将搜索到的学习率更新到模型参数中False

使用注意

  • 暂时只支持单个优化器
  • 暂不支持DDP

用法

使用self.learing_rateself.lr作为学习率参数

class LitModel(LightningModule):
    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=(self.lr or self.learning_rate))


model = LitModel()

# 开启 auto_lr_find标志
trainer = Trainer(auto_lr_find=True)
# 寻找合适的学习率
trainer.tune(model)

使用其他的学习率变量名称

model = LitModel()

# 设置为自己的学习率超参数名称 my_value
trainer = Trainer(auto_lr_find="my_value")

trainer.tune(model)

使用lr_find()查看自动搜索学习率的结果

model = MyModelClass(hparams)
trainer = Trainer()

# 运行学习率搜索
lr_finder = trainer.tuner.lr_find(model)

# 查看搜索结果
lr_finder.results

# 绘制学习率搜索图,suggest参数指定是否显示建议的学习率点
fig = lr_finder.plot(suggest=True)
fig.show()

# 获取最佳学习率或建议的学习率
new_lr = lr_finder.suggestion()

# 更新模型的学习率
model.hparams.lr = new_lr

# 训练模型
trainer.fit(model)

scale_batch_size()

参数详解

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
datamoduleLightningDataModule实例None
mode学习率寻找策略,分为幂次方(默认)和二分(binsearch)power
steps_per_trial每次测试当前batch_size的训练step数量3
init_val初始batch_size大小2
max_trials算法结束前batch_size最大增量25
batch_arg_name存储batch_size的属性名'batch_size'
  • Returns:搜索结果

将在如下地方寻找batch_arg_name

  • model
  • model.hparams
  • trainer.datamodule (如果datamodule传递给了tune())

使用注意

  • 暂时不支持DDP模式

  • 由于需要使用模型的batch_arg_name属性,因此不能直接将dataloader直接传递给trainer.fit(),否则此功能将失效,需要在模型中加载数据

  • 原来模型中的batch_arg_name属性将被覆盖

  • train_dataloader()应该依赖于batch_arg_name属性

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
    

用法

使用Trainer中的auto_scale_batch_size属性
# 默认不执行缩放
trainer = Trainer(auto_scale_batch_size=None)

# 设置搜索策略
trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")

# 寻找最佳batch_szie,并自动设置到模型的batch_size属性中
trainer.tune(model)
使用scale_batch_size()
# 返回搜索结果
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# 覆盖原来的属性(这个过程是自动的)
model.hparams.batch_size = new_batch_size
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值