前言
线性模型是最基本的模型,但包含了很多知识点,本文通过阅读pytorch的源码学习Linear层的构造和实现。
源码解读
常量声明和类型提示
class Linear(Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
__constants__
就是常量,是网络的超参数,放入__constants__
的作用是为jit
优化提供信息。
初始化方法
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.reset_parameters()
factory_kwargs = {'device': device, 'dtype': dtype}
仅仅是避免后面构造Parameter的时候重复写device和dtype,才构造了这个dict。
Parameter
声明如下。
CLASS torch.nn.parameter.Parameter(data=None, requires_grad=True)
代表模块参数的张量。
参数是张量(Tensor)的子类,在与Modules
一起使用时具有非常特殊的性质——当它们被分配为Module
属性时,它们会自动添加到其参数列表中,并会出现在parameters()
迭代器中。
Parameter
的requires_grad
会自动设置为True
register_parameter
为模型增加参数(Parameter),参数可以作为模型的属性通过名字访问。
注意这里的写法:self.register_parameter(‘bias’, None)
不能够直接self.bias = None
reset_parameters
初始化参数值,实现代码如下。
def reset_parameters(self) -> None:
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
kaiming_uniform_
Linear默认使用
kaiming_uniform
初始化网络权重weight。
关于参数初始化方法,也是一个挺大的坑,后面再慢慢填吧。
关于Parameter的讨论
Stackoverflow上有一个很好的解答,发表于2018年6月,只是有些内容是针对pytorch老版本的,希望能够提供一些帮助。
注意:这里有关于Variable
的讨论,但是,现在Pytorch
废弃了Variable
。
回答翻译并摘录如下:
Tensor是多维矩阵。参数的原始形式就是Tensor,即多维矩阵。它是
Variable
类的子类。
(译注:现在不是了)
当与模块关联时,Variable
和Parameter
之间的区别就出现了。当参数作为模型属性与模块关联时,它会自动添加到参数列表中,并且可以使用“参数”迭代器进行访问。 (这句仍然正确)
最初在Torch
中,变量(可能是中间状态)也会在分配时作为模型的参数添加。后来确定了一些用例,其中确定了需要缓存变量而不是将它们添加到参数列表中。 (现在应该用buffer)