Weight Normalization(WN) 权重归一化

      BN/LN/IN/GN都是在数据的层面上做的归一化,而Weight Normalization(WN)是对网络权值W做的归一化。WN的做法是将权值向量w在其欧氏范数和其方向上解耦成了参数向量 v 和参数标量 g 后使用SGD分别优化这两个参数。

      WN也是和样本量无关的,所以可以应用在batchsize较小以及RNN等动态网络中;另外BN使用的基于mini-batch的归一化统计量代替全局统计量,相当于在梯度计算中引入了噪声。而WN则没有这个问题,所以在生成模型,强化学习等噪声敏感的环境中WN的效果也要优于BN。

      WN没有额外参数,这样更节约显存。同时WN的计算效率也要优于要计算归一化统计量的BN。

      但是,WN不具备BN把每一层的输出Y固定在一个变化范围的作用。因此采用WN的时候要特别注意参数初始值的选择

可以认为v是本来的权重

v除以v的模,可以得到它的单位方向向量,再乘以g,g是可学习的

 本来的权重是v的,现在又新增了一个g,得到的新的w是保留了v的方向,然后又新增了一个可学习的幅度

torch.nn.utils.weight_norm(module, name='weight', dim=0)
import torch
from torch import nn
layer = nn.Linear(20, 40)
m = nn.utils.weight_norm(layer, name='weight')
print(m)
print(m.weight_g.size())
print(m.weight_v.size())

手动实现

import torch
from torch import nn
input = torch.randn(8, 3, 20)
linear = nn.Linear(20, 40, bias=False)
wn_layer = nn.utils.weight_norm(linear, name='weight')
wn_output = wn_layer(input)

weight_direction = linear.weight / torch.norm(linear.weight, p=2, dim=1, keepdim=True) #二范数
weight_magnitude = wn_layer.weight_g
output = input @ (weight_direction.permute(1,0).contiguous() * weight_magnitude.permute(1,0).contiguous())
assert torch.allclose(wn_output, output)

  • 5
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值