目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。
目录
1、构建组合损失类
每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。
# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):
def __init__(self):
super(CombinedLoss, self).__init__()
# 定义损失函数权重作为可训练参数
self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True)) # 对抗损失的权重,初始值为0.2
self.w_con = nn.Parameter(torch.ones(1, requires_grad=True)) # 内容感知损失的权重,初始值为0.2
self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True)) # 均方误差损失的权重,初始值为0.2
self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True)) # 随机结构相似性损失的权重,初始值为0.2
self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True)) # 边缘引导损失的权重,初始值为0.2
def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):
return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui
2、调用组合损失类
在计算组合损失之前,需要初始化类对象。
combinedloss = Loss.CombinedLoss()
unet_loss = self.combinedloss(
loss_adv = unet_gan_loss,
loss_con = gen_content_loss,
loss_mse = unet_criterion,
loss_s3im = s3im_loss,
loss_gui = guid_loss)
3、为其构建优化器
最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。
self.lr_weight_optimizer = optim.Adam(
self.combinedloss.parameters(),
lr = 1e-4,
betas=(0.9, 0.999)
)
4、梯度归零
在每次计算总损失之前,需要把每个优化器的梯度归零
self.lr_weight_optimizer.zero_grad()
5、跟新优化器参数
在总损失反向传播之后,需要对优化器的参数进行更新
self.lr_weight_optimizer.step()
6、结果展示
每个权重都会自动更新。