把组合损失中的权重设置为可学习参数

目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。

目录

1、构建组合损失类

2、调用组合损失类

3、为其构建优化器

4、梯度归零

5、跟新优化器参数

6、结果展示


1、构建组合损失类

每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。

# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        # 定义损失函数权重作为可训练参数
        self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True))  # 对抗损失的权重,初始值为0.2 
        self.w_con = nn.Parameter(torch.ones(1, requires_grad=True))  # 内容感知损失的权重,初始值为0.2
        self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True))  # 均方误差损失的权重,初始值为0.2
        self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True))  # 随机结构相似性损失的权重,初始值为0.2
        self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True))  # 边缘引导损失的权重,初始值为0.2


    def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):
        return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui

2、调用组合损失类

在计算组合损失之前,需要初始化类对象。

combinedloss = Loss.CombinedLoss()

unet_loss = self.combinedloss(
                            loss_adv = unet_gan_loss, 
                            loss_con = gen_content_loss, 
                            loss_mse = unet_criterion, 
                            loss_s3im = s3im_loss, 
                            loss_gui = guid_loss)

3、为其构建优化器

最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。

self.lr_weight_optimizer = optim.Adam(
            self.combinedloss.parameters(),
            lr = 1e-4,
            betas=(0.9, 0.999)
        )

4、梯度归零

在每次计算总损失之前,需要把每个优化器的梯度归零

self.lr_weight_optimizer.zero_grad()

5、跟新优化器参数

在总损失反向传播之后,需要对优化器的参数进行更新

self.lr_weight_optimizer.step()

6、结果展示

每个权重都会自动更新。 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值