使用torch lightning进行多优化器训练

使用torch lightning进行多优化器训练

  1. 在configure_optimizers(self)中定义多优化器
def configure_optimizers(self):
        param1 = list(self.encoder.parameters())+list(self.branch1.parameters())
        param2 = list(self.encoder.parameters())+list(self.branch2.parameters())

        optimizer1 = torch.optim.Adam(param1, lr=self.lr)
        optimizer2 = torch.optim.Adam(param2, lr=self.lr)

        scheduler1 = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(optimizer1)
        scheduler2 = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(optimizer2)
        return [optimizer1,optimizer2],[scheduler1,scheduler2]
  1. 在training_step(self,batch,batch_idx,optimizer_idx)中定义什么情况下用什么优化器
    def training_step(self,batch,batch_idx,optimizer_idx):
        ix,iy = batch 
        z = self.encoder(ix)
        
        if optimizer_idx==0:
			out1 = self.branch1(z)
            loss1 = ...
            return loss1
        
        if optimizer_idx==1:
            out2 = self.branch2(z)
            loss2 = ...
            return loss2
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch Lightning Metric 是 PyTorch Lightning 中用于评估模型性能的一种工具。Metric 可以用于监控训练过程中的指标,并在每个 epoch 结束时输出结果。PyTorch Lightning Metric 提供了多种内置的评估指标,如 accuracy、precision、recall、F1 等,并且可以自定义评估指标。 使用 PyTorch Lightning Metric 的基本步骤如下: 1. 定义 Metric 类,继承自 `pl.metrics.Metric` 2. 在类中实现 `update` 方法,用于更新评估指标 3. 在类中实现 `compute` 方法,用于计算最终的评估结果 4. 在 LightningModule 中使用 `self.log()` 方法输出评估结果 例如,下面是一个计算 accuracy 的 Metric 类的示例代码: ```python import torch import pytorch_lightning as pl class Accuracy(pl.metrics.Metric): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds, target): preds = torch.argmax(preds, dim=1) self.correct += torch.sum(preds == target) self.total += target.numel() def compute(self): return self.correct.float() / self.total ``` 在 LightningModule 中使用该 Metric 可以如下使用: ```python class MyModel(pl.LightningModule): def __init__(self): super().__init__() self.accuracy = Accuracy() def training_step(self, batch, batch_idx): ... self.accuracy(preds, target) ... def training_epoch_end(self, outputs): ... self.log('train_acc', self.accuracy.compute(), on_step=False, on_epoch=True) ... ``` 在每个 epoch 结束时,`self.accuracy.compute()` 方法将计算 accuracy 并返回最终的评估结果。`self.log()` 方法用于输出评估结果,其中 `on_epoch=True` 表示只在每个 epoch 结束时输出,而不是每个 batch 结束时都输出。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值