2021-02-24: 21 - Week 4, Lecture 1 - Weight Initialization

 
# -*- coding: utf-8 -*-
"""
# @file name  : grad_vanish_explod.py
# @author     : yunfeiGuo
# @date       : 2021-2-19 16:44:00
# @brief      : weight initialization (gradient vanishing / exploding demo)
"""
import numpy as np
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        # use nn.ModuleList with a list comprehension to stack `layers` identical linear layers
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))  # normal init: mean=0, std=sqrt(1/n)

# flag = 0
flag = 1

if flag:
    layer_nums = 100
    neural_nums = 400
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # standard normal input: mean=0, std=1

    output = net(inputs)
    print(output)

Result:
 

layer:0, std:0.9921884536743164
layer:1, std:1.0107942819595337
layer:2, std:1.0107700824737549
layer:3, std:1.0103548765182495
layer:4, std:1.0256924629211426
layer:5, std:1.0307409763336182
layer:6, std:1.0156118869781494
layer:7, std:1.00758957862854
layer:8, std:0.992605984210968
layer:9, std:0.9987878799438477
layer:10, std:0.9998607039451599
layer:11, std:1.0000765323638916
layer:12, std:1.0003615617752075
layer:13, std:0.9830509424209595
layer:14, std:0.9736397862434387
layer:15, std:0.9665453433990479
layer:16, std:0.951435923576355
layer:17, std:0.9500362873077393
layer:18, std:0.9513030648231506
layer:19, std:0.9312620759010315
layer:20, std:0.9206184148788452
layer:21, std:0.9264180660247803
layer:22, std:0.9173408150672913
layer:23, std:0.9214672446250916
layer:24, std:0.9276127815246582
layer:25, std:0.914120078086853
layer:26, std:0.9065655469894409
layer:27, std:0.8998688459396362
layer:28, std:0.8938019275665283
layer:29, std:0.9010204076766968
layer:30, std:0.9179297685623169
layer:31, std:0.9036451578140259
layer:32, std:0.8873993754386902
layer:33, std:0.8862268924713135
layer:34, std:0.8831592202186584
layer:35, std:0.8924071192741394
layer:36, std:0.9000980257987976
layer:37, std:0.8906688094139099
layer:38, std:0.8929744958877563
layer:39, std:0.8799082040786743
layer:40, std:0.8729921579360962
layer:41, std:0.8602384328842163
layer:42, std:0.8591974377632141
layer:43, std:0.835864782333374
layer:44, std:0.8369930386543274
layer:45, std:0.8339517116546631
layer:46, std:0.8424709439277649
layer:47, std:0.8422181010246277
layer:48, std:0.8254409432411194
layer:49, std:0.824454665184021
layer:50, std:0.8377672433853149
layer:51, std:0.8348366022109985
layer:52, std:0.84702467918396
layer:53, std:0.8349087238311768
layer:54, std:0.8313407301902771
layer:55, std:0.8160363435745239
layer:56, std:0.8238162398338318
layer:57, std:0.8085090517997742
layer:58, std:0.8114701509475708
layer:59, std:0.8043695688247681
layer:60, std:0.7957208752632141
layer:61, std:0.8016817569732666
layer:62, std:0.799397349357605
layer:63, std:0.8005610108375549
layer:64, std:0.81252521276474
layer:65, std:0.8067101240158081
layer:66, std:0.7915506958961487
layer:67, std:0.8123825192451477
layer:68, std:0.8054995536804199
layer:69, std:0.8056119084358215
layer:70, std:0.8105517625808716
layer:71, std:0.7931839227676392
layer:72, std:0.8052643537521362
layer:73, std:0.8095407485961914
layer:74, std:0.8173641562461853
layer:75, std:0.818963885307312
layer:76, std:0.8219993710517883
layer:77, std:0.8288534879684448
layer:78, std:0.8074101805686951
layer:79, std:0.8258078098297119
layer:80, std:0.8096858263015747
layer:81, std:0.836383044719696
layer:82, std:0.8595016002655029
layer:83, std:0.8769122362136841
layer:84, std:0.8849483132362366
layer:85, std:0.9054016470909119
layer:86, std:0.9354943037033081
layer:87, std:0.9218170046806335
layer:88, std:0.9264342188835144
layer:89, std:0.9195851683616638
layer:90, std:0.9228088855743408
layer:91, std:0.9145572781562805
layer:92, std:0.9412762522697449
layer:93, std:0.9753511548042297
layer:94, std:0.9806225299835205
layer:95, std:0.9488538503646851
layer:96, std:0.9764752388000488
layer:97, std:0.9379482269287109
layer:98, std:0.9235386848449707
layer:99, std:0.900111198425293
tensor([[-0.6585,  0.7437, -0.3855,  ...,  1.1179, -1.2283, -1.3067],
        [ 0.5084, -0.7306,  0.2649,  ..., -0.0687,  0.0195,  0.7177],
        [ 0.4830, -0.6636, -0.4006,  ..., -0.4150,  0.0139,  0.3082],
        ...,
        [ 0.8793, -0.7996,  0.3099,  ..., -0.6950,  1.5076,  0.9907],
        [ 0.6845, -0.0335,  0.7977,  ..., -2.0164,  0.1546,  1.4978],
        [ 0.1866, -1.2768, -0.3950,  ...,  0.0869,  0.6206,  0.8633]],
       grad_fn=<MmBackward>)

Process finished with exit code 0

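Why the std stays near 1 in Result 1: with no activation function, each layer computes y = Wx with W_ij ~ N(0, 1/n), so Var(y_i) = n * (1/n) * Var(x_j) = Var(x_j). The variance is preserved layer after layer, so the output neither explodes nor vanishes. A minimal standalone sketch of this check (layer width 400 matches neural_nums above; all names here are illustrative only):

import torch

n = 400                                        # layer width, matching neural_nums above
x = torch.randn(16, n)                         # unit-variance input, like `inputs`
for i in range(100):
    w = torch.randn(n, n) * (1.0 / n) ** 0.5   # W_ij ~ N(0, 1/n)
    x = x @ w.t()                              # one linear layer, no activation
print(x.std())                                 # remains close to 1
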
Modification: judging from the shrinking layer std and the grad_fn=<TanhBackward> in the output below, a tanh activation was added after each linear layer in forward; see the sketch below.
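A minimal sketch of the assumed change to MLP.forward, with the rest of the class unchanged:

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
            x = torch.tanh(x)   # the added activation; it compresses the signal each layer

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x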

Result 2: gradient vanishing occurs (the activation std shrinks layer by layer toward zero).

layer:0, std:0.6390261054039001
layer:1, std:0.49473661184310913
layer:2, std:0.4166508913040161
layer:3, std:0.36537501215934753
layer:4, std:0.3223038911819458
layer:5, std:0.2981663644313812
layer:6, std:0.27496904134750366
layer:7, std:0.25671830773353577
layer:8, std:0.23955628275871277
layer:9, std:0.22771558165550232
layer:10, std:0.21411724388599396
layer:11, std:0.20493920147418976
layer:12, std:0.1959180235862732
layer:13, std:0.18843482434749603
layer:14, std:0.1841781586408615
layer:15, std:0.1790604293346405
layer:16, std:0.17309771478176117
layer:17, std:0.1665753871202469
layer:18, std:0.1616308093070984
layer:19, std:0.15272243320941925
layer:20, std:0.14441688358783722
layer:21, std:0.14256758987903595
layer:22, std:0.14040584862232208
layer:23, std:0.1384734958410263
layer:24, std:0.13550491631031036
layer:25, std:0.13337983191013336
layer:26, std:0.12921476364135742
layer:27, std:0.12870466709136963
layer:28, std:0.12815523147583008
layer:29, std:0.12827163934707642
layer:30, std:0.12512841820716858
layer:31, std:0.12404870241880417
layer:32, std:0.12086830288171768
layer:33, std:0.11759477108716965
layer:34, std:0.11463416367769241
layer:35, std:0.11336363852024078
layer:36, std:0.11248839646577835
layer:37, std:0.11057556420564651
layer:38, std:0.10894306749105453
layer:39, std:0.1080525740981102
layer:40, std:0.1083245798945427
layer:41, std:0.10841522365808487
layer:42, std:0.10763368010520935
layer:43, std:0.10482244193553925
layer:44, std:0.10101202875375748
layer:45, std:0.09761253744363785
layer:46, std:0.09689796715974808
layer:47, std:0.09653888642787933
layer:48, std:0.0945930927991867
layer:49, std:0.09525361657142639
layer:50, std:0.09581011533737183
layer:51, std:0.09559603035449982
layer:52, std:0.09250099211931229
layer:53, std:0.09044396132230759
layer:54, std:0.09124993532896042
layer:55, std:0.08985602110624313
layer:56, std:0.0886460393667221
layer:57, std:0.08779852092266083
layer:58, std:0.08917854726314545
layer:59, std:0.08753976225852966
layer:60, std:0.08559731394052505
layer:61, std:0.08475394546985626
layer:62, std:0.08390384912490845
layer:63, std:0.08435895293951035
layer:64, std:0.08321720361709595
layer:65, std:0.08144424110651016
layer:66, std:0.07844565808773041
layer:67, std:0.07755745947360992
layer:68, std:0.07360151410102844
layer:69, std:0.07208137214183807
layer:70, std:0.06971283257007599
layer:71, std:0.06869462132453918
layer:72, std:0.06869859993457794
layer:73, std:0.06708640605211258
layer:74, std:0.06695549190044403
layer:75, std:0.06639198213815689
layer:76, std:0.06611380726099014
layer:77, std:0.06629636138677597
layer:78, std:0.06580130755901337
layer:79, std:0.06554996222257614
layer:80, std:0.06589429825544357
layer:81, std:0.06597837060689926
layer:82, std:0.06580808013677597
layer:83, std:0.06501759588718414
layer:84, std:0.06554077565670013
layer:85, std:0.06526520103216171
layer:86, std:0.06393258273601532
layer:87, std:0.063775934278965
layer:88, std:0.06375069916248322
layer:89, std:0.06455789506435394
layer:90, std:0.06108592078089714
layer:91, std:0.06084003299474716
layer:92, std:0.059918470680713654
layer:93, std:0.06046328321099281
layer:94, std:0.059151582419872284
layer:95, std:0.05842287465929985
layer:96, std:0.057261671870946884
layer:97, std:0.05646803602576256
layer:98, std:0.05618279054760933
layer:99, std:0.056390754878520966
tensor([[-0.1123,  0.0819,  0.0306,  ..., -0.0470, -0.0046, -0.1025],
        [ 0.0487, -0.0020,  0.0342,  ..., -0.0359, -0.0736,  0.0354],
        [-0.0282,  0.0580,  0.0094,  ...,  0.0653, -0.0048,  0.0259],
        ...,
        [-0.0775, -0.0103, -0.0550,  ...,  0.0482,  0.0015,  0.0910],
        [ 0.0167,  0.0200,  0.0109,  ..., -0.0032,  0.1072, -0.0674],
        [ 0.0229, -0.0024,  0.0324,  ..., -0.0963, -0.0460, -0.0453]],
       grad_fn=<TanhBackward>)

Process finished with exit code 0

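A common remedy for the vanishing activations above is to rescale the initialization for the tanh nonlinearity, e.g. Xavier (Glorot) initialization with the tanh gain. A minimal sketch of how MLP.initialize could be changed (a suggested fix, not part of the original script):

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Xavier normal init with the recommended gain for tanh (5/3);
                # keeps the activation std roughly stable through the tanh layers.
                nn.init.xavier_normal_(m.weight.data, gain=nn.init.calculate_gain('tanh'))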
 
