# Networks: one VAEGen auto-encoder and one MsImageDis discriminator per domain,
# each sized by that domain's input_dim and configured by the shared 'gen'/'dis'
# sub-configs.
self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
# Non-affine instance norm over 512 channels. Its consumer is not visible here;
# NOTE(review): presumably applied to feature maps elsewhere — confirm.
self.instancenorm = nn.InstanceNorm2d(512, affine=False)

# Setup the optimizers.
# NOTE(review): `lr` is not bound in this visible span; it must be defined earlier
# in the enclosing method (presumably lr = hyperparameters['lr']) — confirm.
beta1 = hyperparameters['beta1']
beta2 = hyperparameters['beta2']
# Generator and discriminator optimizers share identical Adam settings.
adam_kwargs = dict(
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=hyperparameters['weight_decay'],
)
dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
# Only parameters with requires_grad=True are handed to the optimizers.
self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], **adam_kwargs)
self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], **adam_kwargs)
# Learning-rate schedulers, one per optimizer, built from the same hyperparameter dict.
self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)