pytorch(14)权值初始化

  1. 权值的方差过大导致梯度爆炸的原因
  2. 方差一致性原则分析Xavier方法与Kaiming初始化方法
    饱和激活函数tanh,非饱和激活函数relu
  3. pytorch提供的十种初始化方法

梯度消失与爆炸

\[H_2 = H_1 * W_2\\ \Delta W_2 = \frac{\partial Loss}{\partial W_2} =\frac{\partial Loss}{\partial out} *\frac{\partial out}{\partial H_2} *\frac{\partial H_2}{\partial W_2} =\frac{\partial Loss}{\partial out} *\frac{\partial out}{\partial H_2}*H_1 \]
\[{梯度消失:}H_1 \rightarrow 0 \Rightarrow \Delta W_2 \rightarrow 0\\ {梯度爆炸:}H_1 \rightarrow \infty \Rightarrow \Delta W_2 \rightarrow \infty \]
\[1. E(X*Y)=E(X)*E(Y)\\ 2. D(X)=E(X^2)-[E(X)]^2\\ 3. D(X+Y)=D(X)+D(Y)\\ 1.2.3. \Rightarrow D(X*Y)=D(X)D(Y)+D(X)*[E(Y)]^2+D(Y)*[E(X)]^2\\ 若E(X)=0,E(Y)=0 \Rightarrow D(X*Y)=D(X)*D(Y) \]
\[H_{11} = \sum ^{n}_{i=0} X_i * W_{1i}\\ D(X*Y) = D(X)*D(Y)\\ D(H_{11})=\sum ^{n}_{i=0} D(X_i)*D(W_1i)=n*(1*1)=n\\ std(H_{11})=\sqrt D(H_11) = \sqrt n\\ D(H_1) = n*D(X)*D(W)=1\\ D(W)=\frac{1}{n}\Rightarrow std(W)=\sqrt \frac {1}{n} \]

Xavier方法与Kaiming方法

Xavier初始化
方差一致性,保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh

\[n_i * D(W)=1\\ n_{i+1} *D(W)=1\\ \Rightarrow D(W)=\frac{2}{n_i+n_i+1} \]
\[W \sim U[-a,a]\\ D(W) = \frac {(-a-a)^2}{12} = \frac {(2a)^2}{12}=\frac {a^2}{3}\\ \frac{2}{n_i+n_{i+1}}=\frac{a^2}{3}\Rightarrow a = \frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}\\ \Rightarrow W \sim U[-\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}},\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}] \]

Kaiming初始化
方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种

\[D(W) = \frac{2}{n_i}\\ D(W) = \frac{2}{(1+a^2)*n_i}\\ std(W) = \sqrt{\frac{2}{(1+a^2)*n_i}} \]

参考文献:
《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》

常用初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化
nn.init.calculate_gain(nonlinearity, param=None)

功能:计算激活函数的方差变化尺度
输入数据的方差和输出数据方差的比例。
参数:

  • nonlinearity:激活函数名称
  • param:激活函数参数,Leaky ReLU的negative_slop
# -*- coding: utf-8 -*-
"""
# @file name  : grad_vanish_explod.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-09-30 10:08:00
# @brief      : 梯度消失与爆炸实验
"""
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import random
import numpy as np
import torch.nn as nn
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))

import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))    # normal: mean=0, std=1

                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                #
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                #
                # nn.init.uniform_(m.weight.data, -a, a)

                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
                nn.init.kaiming_normal_(m.weight.data)

flag = 0
# flag = 1

if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

    output = net(inputs)
    print(output)

# ======================================= calculate gain =======================================

# flag = 0
flag = 1

if flag:

    x = torch.randn(10000)
    out = torch.tanh(x)

    gain = x.std() / out.std()
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值