pytorch nn.Module调用过程详解及weight和bias的值的初始化

首先说明一点:

nn.Module 是所有神经网络单元(neural network modules)的基类
pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数。

举例说明:

x = torch.randn(2, 3) #input 2*3
m = torch.nn.Linear(3, 2)#output 2*2
output = m(x)
print(output)
输出结果:
	tensor([[ 0.1918, -0.1055],
    [ 0.0589,  0.7517]], grad_fn=<AddmmBackward>)

m(x)开始调用线性图存nn.Linear,实际上调用的是基类nn.Module的__call__函数。如下:
在这里插入图片描述
而__call__函数是在红框处调用的forward函数,forward实现的地方在linear.py里面,如下:

@weak_script_method
def forward(self, input):
    return F.linear(input, self.weight, self.bias)

实际上返回的就是:input*weight + bias。

x = torch.randn(2, 3) #input 2*3
m = torch.nn.Linear(3, 2)#out_features 2*2
output = m(x)
print(output)
输出结果:
	tensor([[ 0.1918, -0.1055],
    [ 0.0589,  0.7517]], grad_fn=<AddmmBackward>)

这份代码里面就是:[2,3] *[3,2]最终结果为[2,2]

weight和bias的初始化在linear.py里面,如下:

   def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

W在 U(-bound,bound)中采样,其中:
在这里插入图片描述
fan_in即为W的第二维大小,即Linear所作用的输入向量的维度
bias也在U(−bound,bound)中采样,且bound与W一样。查看weight和bias的值,如下:

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LinearRegression()

for layer,param in model.state_dict().items(): # param is weight or bias(Tensor)
	print(layer,param)

输出结果:
linear.weight tensor([[-0.8357]])
linear.bias tensor([-0.9367])
  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值