Pytorch学习笔记——Linear模型源码学习

前言

线性模型是最基本的模型,但包含了很多知识点,本文通过阅读pytorch的源码学习Linear层的构造和实现。

源码解读

常量声明和类型提示

class Linear(Module):

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

__constants__
就是常量,是网络的超参数,放入__constants__的作用是为jit优化提供信息。

初始化方法

def __init__(self, in_features: int, out_features: int, bias: bool = True,
                device=None, dtype=None) -> None:
    factory_kwargs = {'device': device, 'dtype': dtype}
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
    if bias:
        self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()

factory_kwargs = {'device': device, 'dtype': dtype}

仅仅是避免后面构造Parameter的时候重复写device和dtype,才构造了这个dict。

Parameter

声明如下。

CLASS torch.nn.parameter.Parameter(data=None, requires_grad=True)

代表模块参数的张量。
参数是张量(Tensor)的子类,在与Modules一起使用时具有非常特殊的性质——当它们被分配为Module属性时,它们会自动添加到其参数列表中,并会出现在parameters()迭代器中。
Parameterrequires_grad会自动设置为True

register_parameter

为模型增加参数(Parameter),参数可以作为模型的属性通过名字访问。
注意这里的写法:self.register_parameter(‘bias’, None)
不能够直接self.bias = None

reset_parameters

初始化参数值,实现代码如下。

def reset_parameters(self) -> None:
    # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
    # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
    # https://github.com/pytorch/pytorch/issues/57109
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

kaiming_uniform_

Linear默认使用kaiming_uniform初始化网络权重weight。

关于参数初始化方法,也是一个挺大的坑,后面再慢慢填吧。

关于Parameter的讨论

Stackoverflow上有一个很好的解答,发表于2018年6月,只是有些内容是针对pytorch老版本的,希望能够提供一些帮助。

注意:这里有关于Variable的讨论,但是,现在Pytorch废弃了Variable

回答翻译并摘录如下:

Tensor是多维矩阵。参数的原始形式就是Tensor,即多维矩阵。它是Variable类的子类。
(译注:现在不是了)
当与模块关联时,VariableParameter之间的区别就出现了。当参数作为模型属性与模块关联时,它会自动添加到参数列表中,并且可以使用“参数”迭代器进行访问。 (这句仍然正确)
最初在Torch中,变量(可能是中间状态)也会在分配时作为模型的参数添加。后来确定了一些用例,其中确定了需要缓存变量而不是将它们添加到参数列表中。 (现在应该用buffer)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值