深度学习:原理简明教程07-深度学习:初始化函数——作者:Ling,作者链接: http://www.bdpt.net/cn
欢迎转载,作者:Ling,注明出处:深度学习:原理简明教程07-深度学习:初始化函数
本文主要讨论深度学习中初始化函数的问题。
初始化函数主要作用:对所有参数给一个初始值,如前文中的w和b,给一个初始值,之后可以根据cost函数调整w和b的值。
为什么需要初始化函数?
下面以一个简单的神经网络为例,如果初始化全部参数为0,会是什么情况?
实例:
一个实例,进行一次正向反向求值:
结论:
1)w[1],b[1]没变化
2)w[2],b[2]两列变化一样
解决办法:
用随机初始化或者其他初始化办法,让所有参数不一样,这样学习效果更佳。
主要初始化方法:
Zeros()
Ones()
Constant(value=0)
RandomNormal(mean=0.0, stddev=0.05, seed=None)
RandomUniform(minval=-0.05, maxval=0.05, seed=None)
TruncatedNormal(mean=0.0, stddev=0.05, seed=None)
VarianceScaling(scale=1.0, mode='fan_in', distribution='normal', seed=None)
Orthogonal(gain=1.0, seed=None)
Identity(gain=1.0)
glorot_normal(seed=None)
he_normal(seed=None)
lecun_normal(seed=None)
he_uniform(seed=None)
glorot_uniform(seed=None)
lecun_uniform(seed=None)
Zeros:
全部初始化为0
Ones:
全部初始化为1
Constant:
全初始化为固定值
RandomNormal:
根据正太分布初始化
RandomUniform:
指定最大最小值,然后再区间随机初始化
RandomUniform:
均匀分布随机初始化
TruncatedNormal:
和RandomNormal类似,但是会根据stddev的两倍进行截断:
例如:
当输入参数mean = 0 , stddev =1时
输出是不可能出现[-2,2]以外的点的值
该初始化方法是被推荐使用的初始化方法
VarianceScaling:
scale: 缩放尺度(正浮点数)
mode: "fan_in", "fan_out", "fan_avg"中的一个,用于计算标准差stddev的值。
distribution:分布类型,"normal"或“uniform"中的一个。
当 distribution="normal" 的时候,生成truncated normal distribution(截断正态分布) 的随机数,其中stddev = sqrt(scale / n) ,n的计算与mode参数有关。
如果mode = "fan_in", n为输入单元的结点数;
如果mode = "fan_out",n为输出单元的结点数;
如果mode = "fan_avg",n为输入和输出单元结点数的平均值。
当distribution="uniform”的时候 ,生成均匀分布的随机数,假设分布区间为[-limit, limit],则
limit = sqrt(3 * scale / n)
Orthogonal:
生成正交矩阵的随机数。
当需要生成的参数是2维时,这个正交矩阵是由均匀分布的随机数矩阵经过SVD分解而来。
gain是最后矩阵乘以的系数
Identity:
对角元素都是1,其他都是0的矩阵
gain是乘以该矩阵的系数
[[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]]
只被用于2D的矩阵
lecun_uniform:
实际上是:
VarianceScaling(scale=1.,
mode='fan_in',
distribution='uniform',
seed=seed)
lecun_normal:
实际上是:
VarianceScaling(scale=1.,
mode='fan_in',
distribution='normal',
seed=seed)
he_normal:
实际上是:
VarianceScaling(scale=2.,
mode='fan_in',
distribution='normal',
seed=seed)
he_uniform:
实际上是:
VarianceScaling(scale=2.,
mode='fan_in',
distribution='uniform',
seed=seed)
glorot_normal:
实际上是:
VarianceScaling(scale=1.,
mode='fan_avg',
distribution='normal',
seed=seed)
glorot_uniform:
VarianceScaling(scale=1.,
mode='fan_avg',
distribution='uniform',
seed=seed)