参考:
1、PyTorch torch.optim 传入两个网络参数
2、多任务学习pytorch使用不同学习率同时训练多个网络的方法
在pytorch 中,经常会遇到多个网络交织并存的情况,这样就会遇到多个网络需要同时训练的问题。我们如何解决这个问题呢?如下是解决办法。
首先需要导入package: itertools
import itertools
params2Solu_Net = inverse_dnn.DNN2Solu.parameters()
params2Para_Net = inverse_dnn.DNN2Para.parameters()
params2Net = itertools.chain(params2Solu_Net, params2Para_Net)
# 定义优化方法,并给定初始学习率
# optimizer = torch.optim.SGD(params2Net, lr=init_lr) # SGD
# optimizer = torch.optim.SGD(params2Net, lr=init_lr, momentum=0.8) # momentum
# optimizer = torch.optim.RMSprop(params2Net, lr=init_lr, alpha=0.95) # RMSProp
optimizer = torch.optim.Adam(params2Net, lr=learning_rate) # Adam
# 定义更新学习率的方法
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1/(epoch+1))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.995)