Linear transformation:
Example:
import torch
import torch.nn as nn

input1 = torch.randn(128, 20)
input2 = torch.randn(128, 3, 20)  # the middle * can be any number of extra dimensions
input3 = torch.randn(128, 3, 4, 20)
m = nn.Linear(20, 30)  # in_features=20, out_features=30
output1 = m(input1)
output2 = m(input2)
output3 = m(input3)
print(output1.size(), output2.size(), output3.size())
# Output:
torch.Size([128, 30]) torch.Size([128, 3, 30]) torch.Size([128, 3, 4, 30])
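The shapes above suggest that only the last dimension has to match in_features. The following is a minimal sketch of that rule, not part of the original example; the helper name output_shape and the specific shapes are hypothetical, chosen purely for illustration.

```python
import torch
import torch.nn as nn

m = nn.Linear(20, 30)

def output_shape(*dims):
    # Hypothetical helper: run a random tensor of the given shape through m
    # and return the output shape as a plain tuple.
    return tuple(m(torch.randn(*dims)).shape)

shape_1d = output_shape(20)             # no leading dimensions at all
shape_5d = output_shape(2, 3, 4, 5, 20)  # four leading dimensions

# If the last dimension is not in_features (here 3 instead of 20),
# the underlying matrix multiplication fails with a RuntimeError.
try:
    m(torch.randn(128, 20, 3))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(shape_1d, shape_5d, mismatch_raised)
```

In this sketch the leading dimensions pass through unchanged, and only the final dimension is replaced by out_features.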
The middle * can be any number of dimensions. Here is why:
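input2 = torch.randn(128, 3, 20)
m = nn.Linear(20, 30)
output2 = m(input2)  # shape: (128, 3, 30)
# Merge the two leading dimensions into one, giving a 2-D matrix
input3 = input2.reshape(128 * 3, 20)
output3 = m(input3)  # shape: (384, 30)
# Element-wise comparison against the 3-D result flattened the same way
print(output3 == output2.reshape(128 * 3, -1))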
# Output:
tensor([[1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1]], dtype=torch.uint8)
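As the comparison shows, applying nn.Linear() to a multi-dimensional input is equivalent to multiplying all the leading dimensions together to form a 2-D matrix, applying the linear transformation (that is, the fully connected layer) to each row, and then restoring the original leading dimensions.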
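To make the equivalence concrete, here is a minimal sketch that reproduces the layer by hand as y = x W^T + b on the flattened input. The variable names (lead_shape, flat, manual, matches) and the fixed seed are illustrative assumptions; torch.allclose is used instead of == so that tiny floating-point differences do not cause a false mismatch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed, only so the sketch is repeatable
m = nn.Linear(20, 30)
x = torch.randn(128, 3, 4, 20)

with torch.no_grad():
    lead_shape = x.shape[:-1]                 # (128, 3, 4)
    flat = x.reshape(-1, x.shape[-1])         # (128 * 3 * 4, 20) = (1536, 20)
    # m.weight has shape (out_features, in_features) = (30, 20)
    manual = flat @ m.weight.T + m.bias       # (1536, 30)
    manual = manual.reshape(*lead_shape, -1)  # back to (128, 3, 4, 30)

    # Compare with a tolerance rather than exact equality
    matches = torch.allclose(m(x), manual, atol=1e-5)

print(manual.shape, matches)
```

The tolerance-based comparison is a design choice for robustness: the flattened and unflattened paths may accumulate rounding error slightly differently on some backends.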