torch.nn.Linear(in_features, out_features, bias=True) 函数是一个线性变换函数:
其中,in_features为输入样本的大小,out_features为输出样本的大小,bias默认为true。如果设置bias = false那么该层将不会学习一个加性偏差。
Linear()函数通常用于设置网络中的全连接层。
import torch
x = torch.randn(8, 3) # 输入样本
fc = torch.nn.Linear(3, 5) # 20为输入样本大小,30为输出样本大小
output = fc(x)
print('fc.weight.shape:\n ', fc.weight.shape, fc.weight)
print('fc.bias.shape:\n', fc.bias.shape)
print('output.shape:\n', output.shape)
ans = torch.mm(x, torch.t(fc.weight)) + fc.bias # 计算结果与fc(x)相同
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))
输出结果为:
fc.weight.shape:
torch.Size([5, 3]) Parameter containing:
tensor([[-0.1878, -0.2082, 0.4506],
[ 0.3230, 0.3543, 0.3187],
[-0.0993, -0.0028, -0.1001],
[-0.0479, 0.3248, -0.4867],
[ 0.0574, 0.0451, 0.1525]], requires_grad=True)
fc.bias.shape:
torch.Size([5])
output.shape:
torch.Size([8, 5])
ans.shape:
torch.Size([8, 5])
True
Process finished with exit code 0
首先,nn.linear(3,5)其权重的shape为(5,3),所以x与其相乘时,用torch.t求了nn.linear的转置,这样(83)(35)得到全连接层后的输出维度(85),结果也与fc(x)验证是一致的, torch.mm就是数学上的两个矩阵 相乘。
参考文献:
https://blog.csdn.net/daodaipsrensheng/article/details/117259324