基于pytorch 的 orthogonal_regularization(正交规范化)实现

认为参数需要满足一定条件,希望卷积层参数是正交的。
如果不是正交的,计算与正交之间的距离,然后作为损失进行优化。

本程序给出了orthogonal regularization的pytorch的实现,直接返回模型的损失。

import torch
def orthogonal_regularization(model, device, beta=1e-4):
    r"""
        author: Xu Mingle
        time: 2019年2月19日15:12:43
        input:
            model: which is the model we want to use orthogonal regularization, e.g. Generator or Discriminator
            device: cpu or gpu
            beta: hyperparameter
        output: loss
    """
    
    # beta * (||W^T.W * (1-I)||_F)^2 or 
    # beta * (||W.W.T * (1-I)||_F)^2
    # 若 H < W,可以使用前者, 若 H > W, 可以使用后者,这样可以适当减少内存
    
    
    loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
    
    for name, param in model.named_parameters():
#         print('name is {}'.format(name))
#         print('shape is {}'.format(param.shape))
        if 'weight' in name and param.requires_grad and len(param.shape)==4:
        # 是weight,而不是bias
        # 当然是指定被训练的参数
        # 只对卷积层参数做这样的正则化,而不包括嵌入层(维度是2)等。
            
#             print('shape is {}'.format(param.shape))
#             print('name {}'.format(name))
            
            N, C, H, W = param.shape
#             print('param shape {}'.format(param.shape))
            
            weight = param.view(N * C, H, W)
#             print('flatten shape {}'.format(weight.shape))
            
            weight_squared = torch.bmm(weight, weight.permute(0, 2, 1)) # (N * C) * H * H
#             print('beta_squared shape {}'.format(weight_squared.shape))
            
            ones = torch.ones(N * C, H, H, dtype=torch.float32) # (N * C) * H * H
#             print('ones shape {}'.format(ones.shape))
            
            diag = torch.eye(H, dtype=torch.float32) # (N * C) * H * H
#             print('diag shape {}'.format(diag.shape))
            
            loss_orth += ((weight_squared * (ones - diag).to(device)) ** 2).sum()
            
    return loss_orth * beta
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值