Pytorch lightning-SAM Optimizer

Pytorch lightning下用改写的优化器SAM,需要 rewrite training_step部分

from utility.bypass_bn import enable_running_stats, disable_running_stats
from sam import SAM


class MyModel():
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        #修改1 关闭自动优化
        self.automatic_optimization = False
    def training_step(self, batch, batch_idx):
        # 修改2 固定格式
        opt= self.optimizers()
        
        #first forward-backward step
        enable_running_stats(model)
        predictions = model(inputs)
        loss = self.compute(loss)
        loss.backward() #如果用manual_backward 会报错还不知原因
        opt.first_step(zero_grad=True)
        
        # second forward-backward step
        disable_running_stats(model)
        loss = self.compute(loss)
        loss.backward()
        opt.second_step(zero_grad=True)
        return 
    def configure_optimizers(self):
        lr = self.basic_lr_per_img * self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW
        optimizer = SAM(self.model.parameters(),lr=lr,weight_decay=1e-3)
        return [optimizer]

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值