# 1. Learning-rate decay during training
# Linear learning-rate decay: during the final `num_epochs_decay` epochs,
# subtract a fixed fraction (initial_lr / num_epochs_decay) of the initial
# rates each epoch, so g_lr and d_lr reach ~0 by the last epoch.
# NOTE(review): `g_lr`/`d_lr` are presumably locals initialized from
# self.g_lr/self.d_lr before the epoch loop, and `self.e` the current
# epoch index — confirm against the enclosing training loop (indentation
# of this fragment was lost; it belongs inside the per-epoch loop).
if (self.e+1) > (self.num_epochs - self.num_epochs_decay):
g_lr -= (self.g_lr / float(self.num_epochs_decay))
d_lr -= (self.d_lr / float(self.num_epochs_decay))
# Push the decayed rates into all three optimizers.
self.update_lr(g_lr, d_lr)
print('Decay learning rate to g_lr: {}, d_lr:{}.'.format(g_lr, d_lr))
def update_lr(self, g_lr, d_lr):
    """Set new learning rates on the generator and both discriminator optimizers.

    Args:
        g_lr: learning rate written into every param group of ``self.g_optimizer``.
        d_lr: learning rate written into every param group of
            ``self.d_A_optimizer`` and ``self.d_B_optimizer``.
    """
    # One pass over a (optimizer, rate) table instead of three near-identical loops.
    targets = (
        (self.g_optimizer, g_lr),
        (self.d_A_optimizer, d_lr),
        (self.d_B_optimizer, d_lr),
    )
    for optimizer, rate in targets:
        for group in optimizer.param_groups:
            group['lr'] = rate