nn.Linear(in_features, out_features, bias=True)
对输入数据做线性变换:y=Ax+b
如,m = nn.Linear(20, 30)
注意:虽然这里写的是Ax+b,其实是xA+b;(也可以是128个行向量x构成的矩阵,这样128x20与20x30矩阵相乘的结果是128个数据,每个输出数据是30维行向量)。
参数:
- in_features - 每个输入样本的大小
- out_features - 每个输出样本的大小
- bias - 若设置为False,这层不会学习偏置。默认值:True
形状:
- 输入: (N,in_features)
- 输出: (N,out_features)
变量:
- weight -形状为(out_features x in_features)的模块中可学习的权值
- bias -形状为(out_features)的模块中可学习的偏置
例子: