Pytorch - nn.Linear

Ctrl并点击函数,可以看到nn.Linear源码:

class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

nn.Linear继承于nn.Module,内部函数主要有__init__reset_parameters, forwardextra_repr函数。

__init__(self, in_features, out_features, bias=True)
in_features:前一层网络神经元的个数
out_features: 该网络层神经元的个数

注释:
Applies a linear transformation to the incoming data,
math:y = xA^T + b

Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to False, the layer will not learn an additive bias.
Default: True

Attributes:(nn.linear参数)
weight, bias

Examples:
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

import torch
x = torch.randn(128, 20)  # 输入的维度是(128,20)
m = torch.nn.Linear(20, 30) 
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)  
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))

在这里插入图片描述
在这里插入图片描述
nn.Linear(20, 30) :
x的维度是输入维度:(128,20)
w的维度(公式中相当于A)是:(30,20)
b的维度是30
输出维度是:(128,30)

参考:
[1] pytorch系列 —5以 linear_regression为例讲解神经网络实现基本步骤以及解读nn.Linear函数:https://blog.csdn.net/dss_dssssd/article/details/83892824
[2] torch.nn.Linear()函数的理解:https://blog.csdn.net/m0_37586991/article/details/87861418

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值