# -*- coding: utf-8 -*-
"""
# @file name : grad_canish_explod.py
# @author : yunfeiGuo
# @date : 2021-2-19 16:44:00
# @brief : 权值初始化
"""
import os
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
# from tools.common_tools import transform_invert, set_seed
import numpy as np
import torch
class MLP(nn.Module):
    """Deep stack of bias-free square linear layers for inspecting activation statistics.

    Each forward pass prints the standard deviation of the activations after every
    layer, which makes it easy to see whether a weight initialization keeps the
    signal stable, or makes it explode or vanish with depth.

    Args:
        neural_num: width of every layer (number of input and output features).
        layers: number of stacked ``nn.Linear`` layers.
        activation: optional callable applied after every linear layer
            (e.g. ``torch.tanh``). ``None`` (the default) keeps the network purely
            linear.
    """

    def __init__(self, neural_num, layers, activation=None):
        super(MLP, self).__init__()
        # Build `layers` bias-free Linear(neural_num -> neural_num) modules with a list
        # comprehension; wrapping them in ModuleList registers their parameters.
        self.linears = nn.ModuleList(
            [nn.Linear(neural_num, neural_num, bias=False) for _ in range(layers)]
        )
        self.neural_num = neural_num
        self.activation = activation

    def forward(self, x):
        """Pass ``x`` through every layer, printing the output std after each one.

        If the std becomes NaN (for instance after activations overflow), a message
        is printed and the loop stops early. The activations reached so far are
        returned.
        """
        for i, linear in enumerate(self.linears):
            x = linear(x)
            if self.activation is not None:
                x = self.activation(x)
            std = x.std()  # computed once, reused for the log line and the NaN check
            print("layer:{}, std:{}".format(i, std))
            if torch.isnan(std):
                print("output is nan in {} layers".format(i))
                break
        return x

    def initialize(self):
        """Draw every Linear weight from N(0, 1/neural_num), i.e. std = sqrt(1/neural_num).

        Suppose an input has zero mean and unit variance. Each output unit is then a
        sum of ``neural_num`` products, with variance neural_num * 1 * (1/neural_num) = 1.
        So in a purely linear stack the activation std stays close to 1 with depth.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init.* already runs under no_grad, so pass the Parameter itself
                # instead of mutating `.data`.
                nn.init.normal_(m.weight, mean=0.0, std=np.sqrt(1 / self.neural_num))
# Experiment switch: set flag = 0 to skip the run below.
flag = 1
if flag and __name__ == "__main__":
    # Seed the RNG so repeated runs print the same per-layer statistics.
    torch.manual_seed(1)

    layer_nums = 100
    neural_nums = 400
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    # Standard-normal inputs: mean = 0, std = 1.
    inputs = torch.randn((batch_size, neural_nums))
    output = net(inputs)
    print(output)
"""
结果 / Results, run 1 (purely linear layers, weights ~ N(0, 1/neural_num)):
the activation std stays close to 1 across all 100 layers.
layer:0, std:0.9921884536743164
layer:1, std:1.0107942819595337
layer:2, std:1.0107700824737549
layer:3, std:1.0103548765182495
layer:4, std:1.0256924629211426
layer:5, std:1.0307409763336182
layer:6, std:1.0156118869781494
layer:7, std:1.00758957862854
layer:8, std:0.992605984210968
layer:9, std:0.9987878799438477
layer:10, std:0.9998607039451599
layer:11, std:1.0000765323638916
layer:12, std:1.0003615617752075
layer:13, std:0.9830509424209595
layer:14, std:0.9736397862434387
layer:15, std:0.9665453433990479
layer:16, std:0.951435923576355
layer:17, std:0.9500362873077393
layer:18, std:0.9513030648231506
layer:19, std:0.9312620759010315
layer:20, std:0.9206184148788452
layer:21, std:0.9264180660247803
layer:22, std:0.9173408150672913
layer:23, std:0.9214672446250916
layer:24, std:0.9276127815246582
layer:25, std:0.914120078086853
layer:26, std:0.9065655469894409
layer:27, std:0.8998688459396362
layer:28, std:0.8938019275665283
layer:29, std:0.9010204076766968
layer:30, std:0.9179297685623169
layer:31, std:0.9036451578140259
layer:32, std:0.8873993754386902
layer:33, std:0.8862268924713135
layer:34, std:0.8831592202186584
layer:35, std:0.8924071192741394
layer:36, std:0.9000980257987976
layer:37, std:0.8906688094139099
layer:38, std:0.8929744958877563
layer:39, std:0.8799082040786743
layer:40, std:0.8729921579360962
layer:41, std:0.8602384328842163
layer:42, std:0.8591974377632141
layer:43, std:0.835864782333374
layer:44, std:0.8369930386543274
layer:45, std:0.8339517116546631
layer:46, std:0.8424709439277649
layer:47, std:0.8422181010246277
layer:48, std:0.8254409432411194
layer:49, std:0.824454665184021
layer:50, std:0.8377672433853149
layer:51, std:0.8348366022109985
layer:52, std:0.84702467918396
layer:53, std:0.8349087238311768
layer:54, std:0.8313407301902771
layer:55, std:0.8160363435745239
layer:56, std:0.8238162398338318
layer:57, std:0.8085090517997742
layer:58, std:0.8114701509475708
layer:59, std:0.8043695688247681
layer:60, std:0.7957208752632141
layer:61, std:0.8016817569732666
layer:62, std:0.799397349357605
layer:63, std:0.8005610108375549
layer:64, std:0.81252521276474
layer:65, std:0.8067101240158081
layer:66, std:0.7915506958961487
layer:67, std:0.8123825192451477
layer:68, std:0.8054995536804199
layer:69, std:0.8056119084358215
layer:70, std:0.8105517625808716
layer:71, std:0.7931839227676392
layer:72, std:0.8052643537521362
layer:73, std:0.8095407485961914
layer:74, std:0.8173641562461853
layer:75, std:0.818963885307312
layer:76, std:0.8219993710517883
layer:77, std:0.8288534879684448
layer:78, std:0.8074101805686951
layer:79, std:0.8258078098297119
layer:80, std:0.8096858263015747
layer:81, std:0.836383044719696
layer:82, std:0.8595016002655029
layer:83, std:0.8769122362136841
layer:84, std:0.8849483132362366
layer:85, std:0.9054016470909119
layer:86, std:0.9354943037033081
layer:87, std:0.9218170046806335
layer:88, std:0.9264342188835144
layer:89, std:0.9195851683616638
layer:90, std:0.9228088855743408
layer:91, std:0.9145572781562805
layer:92, std:0.9412762522697449
layer:93, std:0.9753511548042297
layer:94, std:0.9806225299835205
layer:95, std:0.9488538503646851
layer:96, std:0.9764752388000488
layer:97, std:0.9379482269287109
layer:98, std:0.9235386848449707
layer:99, std:0.900111198425293
tensor([[-0.6585, 0.7437, -0.3855, ..., 1.1179, -1.2283, -1.3067],
[ 0.5084, -0.7306, 0.2649, ..., -0.0687, 0.0195, 0.7177],
[ 0.4830, -0.6636, -0.4006, ..., -0.4150, 0.0139, 0.3082],
...,
[ 0.8793, -0.7996, 0.3099, ..., -0.6950, 1.5076, 0.9907],
[ 0.6845, -0.0335, 0.7977, ..., -2.0164, 0.1546, 1.4978],
[ 0.1866, -1.2768, -0.3950, ..., 0.0869, 0.6206, 0.8633]],
grad_fn=<MmBackward>)
Process finished with exit code 0
"""
"""
修改 / Modification: a tanh activation was applied after every linear layer
(the final output's grad_fn is TanhBackward); the N(0, 1/neural_num)
initialization appears unchanged (layer 0 std ≈ 0.64, roughly the std of tanh
applied to a unit-variance input).
结果2 / Results 2: 出现梯度消失 — vanishing: the activation std shrinks steadily
with depth, from about 0.64 at layer 0 to about 0.056 at layer 99.
layer:0, std:0.6390261054039001
layer:1, std:0.49473661184310913
layer:2, std:0.4166508913040161
layer:3, std:0.36537501215934753
layer:4, std:0.3223038911819458
layer:5, std:0.2981663644313812
layer:6, std:0.27496904134750366
layer:7, std:0.25671830773353577
layer:8, std:0.23955628275871277
layer:9, std:0.22771558165550232
layer:10, std:0.21411724388599396
layer:11, std:0.20493920147418976
layer:12, std:0.1959180235862732
layer:13, std:0.18843482434749603
layer:14, std:0.1841781586408615
layer:15, std:0.1790604293346405
layer:16, std:0.17309771478176117
layer:17, std:0.1665753871202469
layer:18, std:0.1616308093070984
layer:19, std:0.15272243320941925
layer:20, std:0.14441688358783722
layer:21, std:0.14256758987903595
layer:22, std:0.14040584862232208
layer:23, std:0.1384734958410263
layer:24, std:0.13550491631031036
layer:25, std:0.13337983191013336
layer:26, std:0.12921476364135742
layer:27, std:0.12870466709136963
layer:28, std:0.12815523147583008
layer:29, std:0.12827163934707642
layer:30, std:0.12512841820716858
layer:31, std:0.12404870241880417
layer:32, std:0.12086830288171768
layer:33, std:0.11759477108716965
layer:34, std:0.11463416367769241
layer:35, std:0.11336363852024078
layer:36, std:0.11248839646577835
layer:37, std:0.11057556420564651
layer:38, std:0.10894306749105453
layer:39, std:0.1080525740981102
layer:40, std:0.1083245798945427
layer:41, std:0.10841522365808487
layer:42, std:0.10763368010520935
layer:43, std:0.10482244193553925
layer:44, std:0.10101202875375748
layer:45, std:0.09761253744363785
layer:46, std:0.09689796715974808
layer:47, std:0.09653888642787933
layer:48, std:0.0945930927991867
layer:49, std:0.09525361657142639
layer:50, std:0.09581011533737183
layer:51, std:0.09559603035449982
layer:52, std:0.09250099211931229
layer:53, std:0.09044396132230759
layer:54, std:0.09124993532896042
layer:55, std:0.08985602110624313
layer:56, std:0.0886460393667221
layer:57, std:0.08779852092266083
layer:58, std:0.08917854726314545
layer:59, std:0.08753976225852966
layer:60, std:0.08559731394052505
layer:61, std:0.08475394546985626
layer:62, std:0.08390384912490845
layer:63, std:0.08435895293951035
layer:64, std:0.08321720361709595
layer:65, std:0.08144424110651016
layer:66, std:0.07844565808773041
layer:67, std:0.07755745947360992
layer:68, std:0.07360151410102844
layer:69, std:0.07208137214183807
layer:70, std:0.06971283257007599
layer:71, std:0.06869462132453918
layer:72, std:0.06869859993457794
layer:73, std:0.06708640605211258
layer:74, std:0.06695549190044403
layer:75, std:0.06639198213815689
layer:76, std:0.06611380726099014
layer:77, std:0.06629636138677597
layer:78, std:0.06580130755901337
layer:79, std:0.06554996222257614
layer:80, std:0.06589429825544357
layer:81, std:0.06597837060689926
layer:82, std:0.06580808013677597
layer:83, std:0.06501759588718414
layer:84, std:0.06554077565670013
layer:85, std:0.06526520103216171
layer:86, std:0.06393258273601532
layer:87, std:0.063775934278965
layer:88, std:0.06375069916248322
layer:89, std:0.06455789506435394
layer:90, std:0.06108592078089714
layer:91, std:0.06084003299474716
layer:92, std:0.059918470680713654
layer:93, std:0.06046328321099281
layer:94, std:0.059151582419872284
layer:95, std:0.05842287465929985
layer:96, std:0.057261671870946884
layer:97, std:0.05646803602576256
layer:98, std:0.05618279054760933
layer:99, std:0.056390754878520966
tensor([[-0.1123, 0.0819, 0.0306, ..., -0.0470, -0.0046, -0.1025],
[ 0.0487, -0.0020, 0.0342, ..., -0.0359, -0.0736, 0.0354],
[-0.0282, 0.0580, 0.0094, ..., 0.0653, -0.0048, 0.0259],
...,
[-0.0775, -0.0103, -0.0550, ..., 0.0482, 0.0015, 0.0910],
[ 0.0167, 0.0200, 0.0109, ..., -0.0032, 0.1072, -0.0674],
[ 0.0229, -0.0024, 0.0324, ..., -0.0963, -0.0460, -0.0453]],
grad_fn=<TanhBackward>)
Process finished with exit code 0
"""