pytorch.nn.Linear
是一个类,下面是它的一些初始化参数
in_features
: 输入样本的张量大小
out_features
: 输出样本的张量大小
bias
: 偏置
它主要是对输入数据做一个线性变换。
y
=
x
A
T
+
b
y=xA^T+b
y=xAT+b
这里A是权重矩阵,b是偏置。他们都是根据 in_features
生成
测试代码:
m = torch.nn.Linear(2, 3)
input = torch.randn(4, 2)
out = m(input)
print(m.weight.shape)
print(m.bias.shape)
print(out.size())
输出:
torch.Size([3, 2])
torch.Size([3])
torch.Size([4, 3])
下面从数学角度解析:
首先输入数据是 [batchsize, in_features]
,这里我们设定的是 4X2
的张量。数学角度来说就是4X2的矩阵。
假设为:
x
=
[
1
2
2
3
2
2
1
1
]
x = \left[ \begin{matrix} 1&2 \\ 2&3 \\ 2&2 \\ 1&1 \end{matrix}\right]
x=⎣⎢⎢⎡12212321⎦⎥⎥⎤
那么这里的每一行数据 [1,2],[2,3][2,2],[1,1]
都代表一个样例,列数代表每一个样例的特征数,也就是 in_features
。所以这个矩阵代表的含义就是:输入五个样例(sample),每个样例由2个特征(feature)表示。
那么 torch.nn.Linear
就是对输入的样例进行一个线性变换。
我们有了输入数据,接下来看Linear的作用。 Linear首先根据 in_features,out_features
构造初始的权重矩阵
A
A
A(weight) [out_features,in_features]
,和偏置
b
b
b (bias) [out_features]
。
初始的权重矩阵
A
A
A, 经过转置之后就是 [in_features,out_features]
,也就是 2X3
矩阵。
这里我们假设
A
T
A^T
AT 为:
A
T
=
[
0.1
0.2
0.1
0.2
0.3
0.4
]
A^T = \left[ \begin{matrix} 0.1 & 0.2 & 0.1 \\ 0.2 & 0.3 & 0.4\\ \end{matrix} \right]
AT=[0.10.20.20.30.10.4]
初始偏置是 [out_features]
,也就是 1X3
的矩阵。
假设 b为:
b
=
[
1
1
1
]
b = \left[\begin{matrix} 1 & 1 & 1 \end{matrix}\right]
b=[111]
于是输入数据
x
x
x 经过线性变换
y
=
x
A
T
+
b
y=xA^T+b
y=xAT+b 之后得到:
[
1
2
2
3
2
2
1
1
]
∗
[
0.1
0.2
0.1
0.2
0.3
0.4
]
+
[
1
1
1
]
=
[
0.5
0.8
0.9
0.8
1.3
1.4
0.6
1.0
1.0
0.3
0.5
0.5
]
+
[
1
1
1
]
=
[
1.5
1.8
1.9
1.8
2.3
2.4
1.6
2.0
2.0
1.3
1.5
1.5
]
\begin{aligned}& \left[ \begin{matrix} 1&2 \\ 2&3 \\ 2&2 \\ 1&1 \end{matrix}\right] *\left[ \begin{matrix} 0.1 & 0.2 & 0.1 \\ 0.2 & 0.3 & 0.4\\\end{matrix} \right] + \left[\begin{matrix} 1 & 1 & 1 \end{matrix}\right] \\ \\&= \left[\begin{matrix}0.5 & 0.8 & 0.9 \\0.8 & 1.3 & 1.4 \\0.6 & 1.0 & 1.0 \\0.3 & 0.5 & 0.5\end{matrix}\right] + \left[\begin{matrix} 1 & 1 & 1 \end{matrix}\right] \\ \\&= \left[\begin{matrix}1.5 & 1.8 & 1.9 \\1.8 & 2.3 & 2.4 \\1.6 & 2.0 & 2.0 \\1.3 & 1.5 & 1.5\end{matrix}\right]\end{aligned}
⎣⎢⎢⎡12212321⎦⎥⎥⎤∗[0.10.20.20.30.10.4]+[111]=⎣⎢⎢⎡0.50.80.60.30.81.31.00.50.91.41.00.5⎦⎥⎥⎤+[111]=⎣⎢⎢⎡1.51.81.61.31.82.32.01.51.92.42.01.5⎦⎥⎥⎤
写这个推导的目的不是结果,而是想通过第一行矩阵的乘法运算,理解输入数据通过线性变换具体进行了一个什么转换。这一块最好能够自己手动推导一下。
最后总结一下这个Linear Layer的输入输出:
输入数据大小是: [batchsize, in_features]
输出数据大小是: [batchsize, out_features]
batchsize
: 输入样例数
in_features
: 输入样例特征数
out_features
: 输出样例特征数
参考:
PyTorch的nn.Linear()详解 :提到了 batchsize
, 我觉得这个角度很新颖
pytorch系列 —5以 linear_regression为例讲解神经网络实现基本步骤以及解读nn.Linear函数 :从源码角度出发,而且还有完整的训练代码,借鉴度很高
官网参数Linear介绍 :基本的介绍
torch.nn.Linear()函数的理解 :用代码实现了Linear的等价方式,有助于理解其原理