1. LayerNorm 的本质和作用
1.1 LayerNorm 的本质
- 对单个数据的指定维度进行 Normalization (归一化)处理。
- 且指定的维度至少得包括最后一个维度。
为什么至少得包括最后一个维度?
- 最后一维通常对应于
特征维度
,对其进行归一化可以使特征在同一尺度上。- 对最后一维进行归一化可以确保每个样本的特征处理一致,
有助于模型的稳定性和收敛
。- 从实现角度来看,对最后一维进行归一化
更多符合张量操作的逻辑
,便于计算和实现。
1.2 LayerNorm 的作用
- 正向的 normalization,
让输入分布稳定
,能帮助模型收敛; - 和前向相比,norm 操作之中
因为均值和方差而引入的梯度
在稳定训练
中起到了更大的作用1。
2. LayerNorm 的代码实现
下列代码来源于博客2 ,本文增加一些对细节的解释。
import torch
import torch.nn as nn
def layer_norm_process(feature: torch.Tensor, beta=0., gamma=1., eps=1e-5):
var_mean = torch.var_mean(feature, dim = -1, unbiased = False)
mean = var_mean[1]
var = var_mean[0]
print('mean', mean.shape)
print('feature', feature.shape)
print('mean[..., None]', mean[..., None].shape)
feature = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)
feature = feature * gamma + beta
return feature
def main():
t = torch.rand(4, 2, 3)
print('t', t)
norm = nn.LayerNorm(normalized_shape=t.shape[-1], eps=1e-5)
t1 = norm(t)
t2 = layer_norm_process(t)
print('t1:\n', t1)
print('t2:\n', t2)
if __name__ == '__main__':
main()
- 为什么 mean = var_mean[1],var = var_mean[0]?
- 这是 torch.var_mean 默认的设置。
- 为什么要加
[..., None]
?- 为了让得到的 mean 和 var 能够增加一个维度、来适配 feature 的维度。
- PyTorch中的[…,None]语法用于向张量添加额外的维度,使其与广播兼容。
mean torch.Size([4, 2])
feature torch.Size([4, 2, 3])
mean[..., None] torch.Size([4, 2, 1])
更多关于 […, None] 的例子如下,几个 None 就是添加几个维度。
import torch
# Original tensor of shape (2, 3)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# Adding one dimension
x1 = x[..., None] # Shape: (2, 3, 1)
# Adding two dimensions
x2 = x[..., None, None] # Shape: (2, 3, 1, 1)
# Adding three dimensions
x3 = x[..., None, None, None] # Shape: (2, 3, 1, 1, 1)
print(x1.shape) # Output: torch.Size([2, 3, 1])
print(x2.shape) # Output: torch.Size([2, 3, 1, 1])
print(x3.shape) # Output: torch.Size([2, 3, 1, 1, 1])
https://tobiaslee.top/2019/11/21/understanding-layernorm/ ↩︎
https://blog.csdn.net/qq_37541097/article/details/117653177?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522164570699716780366593622%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=164570699716780366593622&biz_id=0 ↩︎