【Pytorch】如何自定义带可学习超参的损失函数?以及参数不更新问题记录

自定义具有可学习参数的损失函数

我写自己的带可学习参数的损失函数是按照以下代码改的

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 
        # 这就是自己想要定义的可学习参数

        self.reset_parameters()

    def reset_parameters(self): # 参数初始化
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))    # 查看自己定义的参数
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

原理非常简单,只需要做两件事:

  1. torch.nn.Parameter:定义超参添加到模块并且给个初始值
  2. 添加到optimizer中,这样才能计算梯度反向传播(如果你的超参定义在网络net中,# 1 这样就行;如果你和我一样自己重新写了个脚本其中有超参数,就需要像 # 2 将其单独添加到optimizer)
# 1
optimizer = torch.optim.SGD( net.parameters()], lr=lr, momentum=momentum,weight_decay=weight_decay)
# 2
optimizer = torch.optim.SGD([{"params": net.parameters()}, {"params": criterion.parameters()}], lr=lr, momentum=momentum,weight_decay=weight_decay)

上面代码执行完全没有问题,参数也会更新,对其debug时候的参数列表是下面这样的,其属性是parameter并且grad是求出值了的,正常来讲debug就应该是这样
在这里插入图片描述

定义的可学习参数不更新

但是按照上述思路在自己的项目中定义超参之后debug发现一直是初始化值,并不会自动更新。下面是我按照上述修改的自己的损失函数:

import torch
from torch import nn

# 只自定义损失函数的前向计算公式,训练时会自动反向传播
class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()
        device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device("cpu"))
        # 设置前置网络及 可学习参数
        self.lamda1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)
        self.lamda2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)
        self.beta1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)
        self.beta2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)

        self.reset_params()

    def reset_params(self):
        # 初始化
        self.lamda1.data.fill_(1)
        self.lamda2.data.fill_(1)
        self.beta1.data.fill_(1)
        self.beta2.data.fill_(1)

    def forward(self, output, target):
        [c1, c2, c3, d1, d2, d3] = output
        [fun_1, fun_2, fun_3, D1, D2, D3] = target

        L = nn.MSELoss(reduction='none')

        # 置信度损失
        L1 = self.lamda1 * (
                (1 + fun_1 * self.beta1) * L(c1, fun_1) + (1 + fun_2 * self.beta1) * L(c2, fun_2) +
                (1 + fun_3 * self.beta1) * L(c3, fun_3)
        )

        # 直径预测损失
        L2 = self.lamda2 * (
                (1 + fun_1 * self.beta2) * L(d1, D1) + (1 + fun_2 * self.beta2) * L(d2, D2) + (
                1 + fun_3 * self.beta2) * L(d3, D3)
        )
        loss = L1 + L2
        print('lamda2:', self.lamda1.item())

        return loss

debug过程中发现并不会自动更新,以下是debug界面,可以看出和上面不同的是我设置的超参数beta1其属性就是一个变量一个tensor,并非是超参数parameter,并且其grad为None,因此其始终没有更新
在这里插入图片描述
我用下面代码打印我的损失函数的参数列表,结果印证的我的猜想:我设置的超参数其实并没有添加到参数列表中,所以才不会反向传播进行更新。

MyLoss = MyLoss()
for name, parameters in MyLoss.named_parameters():
    print(name, ':', parameters, parameters.size())
print(list(MyLoss.parameters()))

运行结果如下图:(正常来讲是应该包括我设置的四个超参数的列表,但实际上一个也没有)
在这里插入图片描述

解决办法

最后解决办法我也意想不到,这是我一开始定义参数的方式:(超参数设置不成功,定义成了变量,不更新)

self.lamda1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)

这是我现在定义的方式:(超参数,自动更新)

self.lamda1 = torch.nn.Parameter(torch.ones(1).to(device).requires_grad_())

更改之后同样的代码打印参数列表,结果如下:(参数已经添加到module中,参数也能正常更新了)
在这里插入图片描述
运行一下,看看是否真的更新了(这里只打印一个参数):搞定!
在这里插入图片描述

具体的原因我也还没搞清楚,有知道的大神烦请评论区赐教,后续弄明白原理了再来更新!

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值