本文主要讲述最简单的线性回归函数,个人理解定义一个nn.Linear就相当于定义下面的函数:
讲解上述公式在pytorch的实现,主要包括nn.Linear的源码解读和实例展示。
1. nn.Linear 源码解读
先看一下Linear类的实现:源码地址
Linear继承于nn.Module,内部函数主要有__init__,reset_parameters, forward和 extra_repr函数
,下面是部分源码:
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features)