Multi-optimizer training with PyTorch Lightning
- Define the optimizers in configure_optimizers(self); here each optimizer gets the shared encoder parameters plus its own branch head
import torch
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

def configure_optimizers(self):
    # Both parameter groups share the encoder; each adds one branch head.
    param1 = list(self.encoder.parameters()) + list(self.branch1.parameters())
    param2 = list(self.encoder.parameters()) + list(self.branch2.parameters())
    optimizer1 = torch.optim.Adam(param1, lr=self.lr)
    optimizer2 = torch.optim.Adam(param2, lr=self.lr)
    # LinearWarmupCosineAnnealingLR requires warmup_epochs and max_epochs;
    # self.warmup_epochs / self.max_epochs are assumed to be set in __init__.
    scheduler1 = LinearWarmupCosineAnnealingLR(
        optimizer1, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs)
    scheduler2 = LinearWarmupCosineAnnealingLR(
        optimizer2, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs)
    return [optimizer1, optimizer2], [scheduler1, scheduler2]
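By default Lightning steps each scheduler in the returned list once per epoch. If the warmup should advance per batch instead, each scheduler can be wrapped in Lightning's scheduler configuration dict. A minimal sketch of the alternative return value, replacing the last line above and reusing the same optimizer and scheduler objects:

    # Step the schedulers after every optimizer step instead of once per epoch.
    return (
        [optimizer1, optimizer2],
        [{"scheduler": scheduler1, "interval": "step"},
         {"scheduler": scheduler2, "interval": "step"}],
    )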
- In training_step(self, batch, batch_idx, optimizer_idx), specify which optimizer trains on which loss. Under Lightning 1.x automatic optimization, training_step runs once per optimizer for every batch, and optimizer_idx identifies the active one (a Lightning >= 2.0 variant is sketched after the code)
def training_step(self, batch, batch_idx, optimizer_idx):
    ix, iy = batch
    z = self.encoder(ix)
    if optimizer_idx == 0:
        out1 = self.branch1(z)
        loss1 = ...  # branch-1 loss computed from out1 (and iy)
        return loss1
    if optimizer_idx == 1:
        out2 = self.branch2(z)
        loss2 = ...  # branch-2 loss computed from out2 (and iy)
        return loss2
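Note that the optimizer_idx argument only exists in PyTorch Lightning 1.x; Lightning 2.0 removed it, and multiple optimizers now require manual optimization. A minimal sketch of the equivalent training step under Lightning >= 2.0, with the loss computations left elided as above:

def __init__(self):
    super().__init__()
    ...
    self.automatic_optimization = False  # required for multiple optimizers in PL >= 2.0

def training_step(self, batch, batch_idx):  # no optimizer_idx in PL >= 2.0
    opt1, opt2 = self.optimizers()
    ix, iy = batch

    out1 = self.branch1(self.encoder(ix))
    loss1 = ...  # branch-1 loss, as above
    opt1.zero_grad()
    self.manual_backward(loss1)
    opt1.step()

    # Recompute the encoding: the first backward pass freed its graph.
    out2 = self.branch2(self.encoder(ix))
    loss2 = ...  # branch-2 loss, as above
    opt2.zero_grad()
    self.manual_backward(loss2)
    opt2.step()

With manual optimization, Lightning also stops stepping the schedulers returned from configure_optimizers automatically; fetch them with self.lr_schedulers() and step them yourself.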