在PyTorch中为可学习参数施加约束或正则项的方法

根据不同的需求,在PyTorch中有时需要为模型的可学习参数施加自定义的约束或正则项(regular term),下面具体介绍在PyTorch中为可学习参数施加约束或正则项的方法,先看一下为损失函数(Loss function)施加正则项的具体形式,如下为L2正则项:

Loss = L(w)+\lambda \sum _{i}w_{i}^{2}

在上式中,L(w)是训练误差关于可学习参数w的函数,右边的第二项表示L2正则项。在PyTorch中L2正则项是默认内置实现的,其中的weight_decay就表示L2正则项的\lambda超参数。具体如下:

optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=0.01)

根据不同的需求,怎样自定义自己的正则项函数呢?具体示例如下:

import torch

torch.manual_seed(1)

N, D_in, H, D_out = 10, 5, 5, 1
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

criterion = torch.nn.MSELoss()
lr = 1e-4
weight_decay = 0  # for torch.optim.SGD
lmbd = 0.9  # for custom L2 regularization

optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

for t in range(100):
    y_pred = model(x)

    # Compute and print loss.
    loss = criterion(y_pred, y)

    optimizer.zero_grad()

    reg_loss = None
    for param in model.parameters():
        if reg_loss is None:
            reg_loss = 0.5 * torch.sum(param**2)
        else:
            reg_loss = reg_loss + 0.5 * param.norm(2)**2

    loss += lmbd * reg_loss

    loss.backward()

    optimizer.step()

for name, param in model.named_parameters():
    print(name, param)

在上述代码中,如下部分可根据自己的需求,自定义自己的正则项约束:

reg_loss = None
    for param in model.parameters():
        if reg_loss is None:
            reg_loss = 0.5 * torch.sum(param**2)
        else:
            reg_loss = reg_loss + 0.5 * param.norm(2)**2

 

如果您觉得我的文章对您有所帮助,欢迎扫码进行赞赏!

参考:

1. How does one implement Weight regularization (l1 or l2) manually without optimum?

2. torch.norm

3. How to add a L2 regularization term in my loss function?

  • 8
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值