Pytorch-Lightning中的训练器—Trainer

Pytorch-Lightning中的训练器—Trainer

Trainer()

常用参数

参数名称含义默认值接受类型
callbacks添加回调函数或回调函数列表None(ModelCheckpoint默认值)Union[List[Callback], Callback, None]
enable_checkpointing是否使用callbacksTruebool
gpus使用的gpu数量(int)或gpu节点列表(list或str)None(不使用GPU)Union[int, str, List[int], None]
precision指定训练精度32(full precision)Union[int, str]
default_root_dir模型保存和日志记录默认根路径None(os.getcwd())Optional[str]
logger设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dirTrue(默认日志记录)Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]
max_epochs最多训练轮数(指定为**-1可以设置为无限次**)None(1000)Optional[int]
min_epochs最少训练轮数None(1)Optional[int]
max_steps最大网络权重更新次数-1(禁用)Optional[int]
min_steps最少网络权重更新次数None(禁用)Optional[int]
weights_save_path权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径None(default_root_dir)Optional[str]
log_every_n_steps更新n次网络权重后记录一次日志50int
auto_scale_batch_size自动搜索最佳batch_size并保存到模型的self.bacth_sizeFalseUnion[str, bool]
auto_lr_find自动搜索最佳学习率并存储到self.lrself.learing_rateFalseUnion[str, bool]
accumulate_grad_batches每k次batches累计一次梯度NoneUnion[int, Dict[int, int], None]
check_val_every_n_epoch每n个train epoch执行一次验证1int
num_sanity_val_steps开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据2int

额外的解释

  • 这里max_steps/min_steps中的step就是指的是优化器的step,优化器每step一次就会更新一次网络权重
  • 梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size。

Trainer.fit()

常用参数

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None

ckpt_path参数详解(从之前的模型恢复训练)

​ 使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。

示范
net = MyNet(...)
trainer = pl.Trainer(...)
# 假设模型保存在./ckpt中
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')
使用注意
  • 请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数

Trainer.test()

常用参数

参数名称含义默认值
modelLightningModule实例None(使用**fit()**传递的模型)
verbose是否打印测试结果True
dataloaders测试数据加载器(可以使用torch.utils.data.DataLoader)None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None
  • 16
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
这段代码的翻译如下: ``` if not args.two_steps: # 如果参数没有设置 two_steps,直接进行模型测试 trainer.test() step2_model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max", # 定义一个 ModelCheckpoint 回调函数,用于保存第二阶段训练的最佳模型 filename='{epoch}-{Step2Eval/f1:.2f}', dirpath="output", save_weights_only=True ) if args.two_steps: # 如果参数设置了 two_steps,进行两阶段训练 # 构建第二阶段训练所需的模型与训练器 # 使用 Step2Eval/f1 作为评估指标 lit_model_second = TransformerLitModelTwoSteps(args=args, model=lit_model.model, data_config=data_config) step_early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=6, check_on_train_epoch_end=False) callbacks = [step_early_callback, step2_model_checkpoint] # 定义回调函数列表,包括 EarlyStopping 和 ModelCheckpoint trainer_2 = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", gpus=gpu_count, accelerator=accelerator, plugins=DDPPlugin(find_unused_parameters=False) if gpu_count > 1 else None) # 构建训练器 trainer_2.fit(lit_model_second, datamodule=data) # 进行第二阶段训练 trainer_2.test() # 进行测试 ``` 该代码的功能是进行两阶段训练,如果参数没有设置 two_steps,则直接进行模型测试;如果设置了 two_steps,则进行第二阶段训练,训练过程使用 EarlyStopping 和 ModelCheckpoint 回调函数,并进行测试。其,第二阶段训练使用了一个新的模型。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值