深度神经网络中的正交规范化

https://zhuanlan.zhihu.com/p/98873800

为什么要用正交规范化?

在神经网络中,我们都会有矩阵乘法,即使是卷积神经网络CNN中。正交阵的好处是,如果一个矩阵与一个正交阵相乘,这个矩阵的范数不会变化。(如,一个二维向量在一个二维的坐标系如何旋转长度是不会发生变化的。附:旋转矩阵是正交矩阵,正交变换不改变矩阵的F范数、2范数,不改变向量的2范数)

正交阵的这个特性在梯度反向传播时有一定好处,特别是梯度爆炸和梯度消散的情况。

为什么矩阵范数不变会有助于梯度爆炸或梯度消散呢?

举个例子,以L2范数为例,L2范数等于矩阵所有元素的平方和的平方根。如果能让卷积核这个矩阵是正交阵,那么特征图(feature map)这个矩阵的范数就不变了。特征图的L2范数不变化,使得前后特征图的最大值都不会很大。而梯度反向传播时权重的梯度会用到输入并与之相乘,当然这有助于梯度爆炸。输入值也不会变得越来越小,这样也有助于梯度的保持(不考虑梯度越往后本身就越小,只考虑因为要与输入相乘这个因素)

为何保证范数就可以缓解梯度消失和爆炸,现在也还没有完全分析透彻。BigGAN文章中就试图用谱范数来追踪GAN的模型塌陷。因此,希望在运行过程中,卷积核是一个正交阵。保持特征图的范数。

如何运用?

正交阵的一个条件是: 其中W是正交阵,I 是单位阵。但是,实际过程中,卷积核W并不是一个正交阵,因此等式是一个非0矩阵。可以认为非0元素越多,我们越不喜欢这个卷积核。如2017年LCLR文章Neural Photo Editing with Introspective Adversarial Networks(文末附链接)就是使用了L1范数作为损失。就是对每个元素取绝对值,然后求和。
在这里插入图片描述

BigGAN做了一些改进。它不使用L1范数,而是L2范数。并且认为对角线上有其他约束。下式中,I矩阵 表示每个元素都为1.
在这里插入图片描述

相关代码

import torch
def orthogonal_regularization(model, device, beta=1e-4):
    # beta * (||W^T.W * (1-I)||_F)^2 or
    # beta * (||W.W.T * (1-I)||_F)^2
    # 若 H < W,可以使用前者, 若 H > W, 可以使用后者,这样可以适当减少内存
    loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
    for name, param in model.named_parameters():
        if 'weight' in name and param.requires_grad and len(param.shape)==4:
        # 是weight,而不是bias
        # 当然是指定被训练的参数
        # 只对卷积层参数做这样的正则化,而不包括嵌入层(维度是2)等。
            N, C, H, W = param.shape
            weight = param.view(N * C, H, W)
            weight_squared = torch.bmm(weight, weight.permute(0, 2, 1)) # (N * C) * H * H
            ones = torch.ones(N * C, H, H, dtype=torch.float32) # (N * C) * H * H
            diag = torch.eye(H, dtype=torch.float32) # (N * C) * H * H
            loss_orth += ((weight_squared * (ones - diag).to(device)) ** 2).sum()
            
    return loss_orth * beta
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值