Pytorch学习Day05[连载]

一、发现问题

在搭建网络后,有时候会出现梯度爆炸或梯度消失,这是因为每个输出节点的标准差过大或者过小
先上源码

import torch
import torch.nn as nn
import random
import numpy as np

def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(1)

class MLP(nn.Module):
	def __init__(self,neural_num,layers):
		super(MLP,self).__init__()
		self.linears=nn.ModuleList([nn.Linear(neural_num,neural_num,bias=False) for i in range(layers)])
		self.neural_num=neural_num
	
	def forward(self,x):
		for (i,linear) in enumerate(self.linears):
			x=linear(x)
			print("layer:{}, std:{}".format(i,x.std())) #D(X*Y)=D(X)*D(Y)
            if torch.isnan(x.std()):
                print('output is nan in {} layers'.format(i))
                break
		return x
	
	def initialize(self):
		for m in self.modules():
			if isinstance(m,nn.Linear):
				nn.init.normal_(m.weight.data) #采用标准正态分布 0均值 1标准差

layer_nums=100
neural_nums=256
batch_size=16

net=MLP(neural_nums,layer_nums)
net.initialize()
inputs=torch.randn((batch_size,neural_nums))
output=net(inputs)
print(output)

结果:
在这里插入图片描述
再细致一点
在这里插入图片描述

二、解决问题

1、无激活函数

在初始化权重在标准正态分布基础上对方差进行调整为1/n,即在initialize(self)那里调整

nn.init.normal_(m.weight.data)改为
nn.init.normal_(m.weight.data,std=np.sqrt(1/self.neural_num)) 

2、tanh激活函数

就调整为均匀分布
版本1:

nn.init.normal_(m.weight.data)改为

a=np.sqrt(6/(self.neural_num+self.neural_num))
tanh_gain=nn.init.calculate_gain('tanh')
a*=tanh_gain
nn.init.uniform_(m.weight.data,-a,a)

版本2:

nn.init.normal_(m.weight.data)改为

tanh_gain=nn.init.calculate_gain('tanh')
nn.init.xavier_uniform_(m.weight.data,gain=tanh_gain)

3、relu激活函数

版本1:

nn.init.normal_(m.weight.data,std=np.sqrt(2/self.neural_num))

版本2:

nn.init.kaiming_normal_(m.weight.data)

说明:方差增益函数是评价两个神经元之间的标准差减少的幅度。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值