# PyTorch Lightning with a custom SAM optimizer: automatic optimization cannot
# express SAM's two forward/backward passes, so training_step must be rewritten
# to drive the optimizer manually.
import torch

from sam import SAM
from utility.bypass_bn import enable_running_stats, disable_running_stats
class MyModel():
    """Model trained with the SAM (Sharpness-Aware Minimization) optimizer.

    SAM requires two forward/backward passes per batch (perturb weights,
    then update), which PyTorch Lightning's automatic optimization cannot
    express — so this class disables it and steps the optimizer manually
    inside ``training_step``.

    NOTE(review): this class clearly relies on ``LightningModule`` features
    (``self.optimizers()``, ``self.automatic_optimization``) but declares no
    base class — it should subclass ``pytorch_lightning.LightningModule``;
    confirm and add the base where the module is actually defined.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Change 1: turn off automatic optimization so training_step can
        # call the SAM optimizer's first_step/second_step itself.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        """Run one two-pass SAM update on ``batch`` and return the loss.

        Assumes ``batch`` unpacks into ``(inputs, targets)`` and that
        ``self.compute(predictions, targets)`` returns a scalar loss —
        TODO confirm against the actual data pipeline.
        """
        # Change 2: fetch the manually driven optimizer (fixed pattern
        # for manual optimization in Lightning).
        opt = self.optimizers()
        inputs, targets = batch

        # --- First forward-backward pass: ascend to the sharpness peak.
        # Batch-norm running stats must only be updated on this pass.
        enable_running_stats(self.model)
        predictions = self.model(inputs)
        loss = self.compute(predictions, targets)
        # self.manual_backward(loss) reportedly raised an error here
        # (cause unknown) — plain backward() is used instead.
        loss.backward()
        opt.first_step(zero_grad=True)

        # --- Second forward-backward pass at the perturbed weights.
        # The forward pass MUST be re-run: gradients have to be taken at
        # the perturbed parameters, not recomputed from the first loss.
        disable_running_stats(self.model)
        predictions = self.model(inputs)
        self.compute(predictions, targets).backward()
        opt.second_step(zero_grad=True)

        # Return the first-pass loss so Lightning can log/display it.
        return loss

    def configure_optimizers(self):
        """Build a SAM optimizer wrapping AdamW as its base optimizer.

        Learning rate is scaled linearly with the effective batch size
        (per-image LR x batch per device x number of GPUs).
        """
        lr = self.basic_lr_per_img * self.batch_size_per_device * self.gpus
        # SAM takes the base optimizer CLASS as its second argument; the
        # original code created AdamW, immediately overwrote it, and never
        # passed it to SAM at all.
        base_optimizer = torch.optim.AdamW
        optimizer = SAM(
            self.model.parameters(),
            base_optimizer,
            lr=lr,
            weight_decay=1e-3,
        )
        return [optimizer]