# Kaiming Initialization for PyTorch Neural Networks

## Gain for a given nonlinearity

```python
torch.nn.init.calculate_gain(nonlinearity, param=None)
```
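As a quick illustration of `calculate_gain` (a minimal sketch; the values follow PyTorch's documented gain table), the recommended gain for ReLU is √2, and for leaky ReLU it depends on the negative slope passed via `param`:

```python
import math
import torch.nn as nn

# Recommended gain for ReLU is sqrt(2)
gain = nn.init.calculate_gain('relu')
print(gain)

# For leaky_relu, param is the negative slope a,
# and the gain is sqrt(2 / (1 + a^2))
leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
print(leaky_gain)
```

These gains are typically passed to `xavier_uniform_` / `xavier_normal_`, as in the examples below.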

## fan_in and fan_out

PyTorch's source code for computing fan_in and fan_out:

```python
def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.ndimension()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed "
                         "for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
```
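To see what this computes in practice (a small sketch; `_calculate_fan_in_and_fan_out` is a private helper, so the stable way to use these numbers is through the public `kaiming_*`/`xavier_*` functions), consider a Conv2d-shaped weight:

```python
import torch

# Conv2d weight shape: (out_channels, in_channels, kH, kW).
# receptive_field_size = kH * kW = 9, so
# fan_in  = in_channels  * 9 = 8  * 9 = 72
# fan_out = out_channels * 9 = 16 * 9 = 144
w = torch.empty(16, 8, 3, 3)
fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(w)
print(fan_in, fan_out)
```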


## Xavier initialization

An explanation of Xavier initialization: https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/

Xavier initialization keeps the variance of the output y of a layer equal to the variance of its input x, so signal magnitudes are preserved as activations flow through the network.
(1) Xavier uniform distribution

```python
torch.nn.init.xavier_uniform_(tensor, gain=1)
```


```python
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
```


(2) Xavier normal distribution

```python
torch.nn.init.xavier_normal_(tensor, gain=1)
```
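The variance-preserving claim can be checked empirically (a minimal sketch; the layer sizes and sample count here are arbitrary choices for the demonstration). With gain 1, `xavier_normal_` draws weights with std `sqrt(2 / (fan_in + fan_out))`, so for a square layer a unit-variance input yields a roughly unit-variance output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 256, 256
w = torch.empty(fan_out, fan_in)
nn.init.xavier_normal_(w)          # std = sqrt(2 / (fan_in + fan_out))

x = torch.randn(1024, fan_in)      # unit-variance input
y = x @ w.t()

# Var(y) = fan_in * Var(w) = 256 * (2 / 512) = 1, up to sampling noise
print(y.var().item())
```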


## Kaiming initialization

Xavier initialization works well with tanh but poorly with ReLU, so Kaiming He proposed an initialization method tailored to ReLU. By default, PyTorch initializes convolutional layer weights with the Kaiming uniform distribution (using `a = sqrt(5)` in `reset_parameters`).
(1) Kaiming uniform distribution

```python
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
```


Weights are sampled from $U(-bound, bound)$, where

$$bound = \sqrt{\frac{6}{(1+a^{2})\times fan\_in}}$$

- `a` – the negative slope of the rectifier used after this layer (0 for ReLU by default).
- `mode` – either `'fan_in'` (default) or `'fan_out'`. Choosing `'fan_in'` preserves the magnitude of the variance of the weights in the forward pass; choosing `'fan_out'` preserves the magnitudes in the backward pass.
```python
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
```
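We can check the bound formula directly on this example (a small sketch; the shape `(3, 5)` gives `fan_in = 5`, and with ReLU the effective slope is `a = 0`, so the bound reduces to `sqrt(6 / fan_in)`):

```python
import math
import torch
import torch.nn as nn

w = torch.empty(3, 5)                # fan_in = tensor.size(1) = 5
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

bound = math.sqrt(6.0 / 5)           # bound = sqrt(6 / ((1 + 0^2) * fan_in))
# Every sampled weight lies inside [-bound, bound]
print(w.abs().max().item() <= bound)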


(2) Kaiming normal distribution

```python
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
```


Weights are sampled from $N(0, std^{2})$, where

$$std = \sqrt{\frac{2}{(1+a^{2})\times fan\_in}}$$

```python
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
```
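The std formula can likewise be verified empirically in `fan_out` mode (a minimal sketch; the tensor shape is chosen large enough that the sample std is close to the theoretical value). With ReLU (`a = 0`) and `mode='fan_out'`, the denominator uses `fan_out` instead of `fan_in`:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
w = torch.empty(256, 128)            # fan_out = tensor.size(0) = 256
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

std = math.sqrt(2.0 / 256)           # std = sqrt(2 / ((1 + 0^2) * fan_out))
# Sample std over 256*128 = 32768 draws should sit close to the target
print(round(w.std().item(), 4), round(std, 4))
```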

