BN/LN/IN/GN都是在数据的层面上做的归一化,而Weight Normalization(WN)是对网络权值W做的归一化。WN的做法是将权值向量w在其欧氏范数和其方向上解耦成了参数向量 v 和参数标量 g 后使用SGD分别优化这两个参数。
WN也是和样本量无关的,所以可以应用在batchsize较小以及RNN等动态网络中;另外BN使用的基于mini-batch的归一化统计量代替全局统计量,相当于在梯度计算中引入了噪声。而WN则没有这个问题,所以在生成模型,强化学习等噪声敏感的环境中WN的效果也要优于BN。
WN没有额外参数,这样更节约显存。同时WN的计算效率也要优于要计算归一化统计量的BN。
但是,WN不具备BN把每一层的输出Y固定在一个变化范围的作用。因此采用WN的时候要特别注意参数初始值的选择
可以认为v是本来的权重
v除以v的模,可以得到它的单位方向向量,再乘以g,g是可学习的
本来的权重是v的,现在又新增了一个g,得到的新的w是保留了v的方向,然后又新增了一个可学习的幅度
torch.nn.utils.weight_norm(module, name='weight', dim=0)
import torch from torch import nn layer = nn.Linear(20, 40) m = nn.utils.weight_norm(layer, name='weight') print(m) print(m.weight_g.size()) print(m.weight_v.size())
手动实现
import torch from torch import nn input = torch.randn(8, 3, 20) linear = nn.Linear(20, 40, bias=False) wn_layer = nn.utils.weight_norm(linear, name='weight') wn_output = wn_layer(input) weight_direction = linear.weight / torch.norm(linear.weight, p=2, dim=1, keepdim=True) #二范数 weight_magnitude = wn_layer.weight_g output = input @ (weight_direction.permute(1,0).contiguous() * weight_magnitude.permute(1,0).contiguous()) assert torch.allclose(wn_output, output)
Weight Normalization(WN) 权重归一化
于 2022-04-10 12:01:32 首次发布