torch.nn.Linear 是 pytorch 的线性变换层,定义如下:
Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)
全连接层 Fully Connect 一般就就用这个函数来实现。因此在潜意识里,变换的输入张量的 shape 为 (batchsize, in_features),输出张量的 shape 为 (batchsize, out_features)。
当然这是常用的方式,但是 Linear 的输入张量的维度其实并不需要必须为上述的二维,多维也是完全可以的,Linear 仅是对输入的最后一维做线性变换,不影响其他维。
可以看下官网的解释:Linear — PyTorch 1.11.0 documentation
一个例子如下:
import torch
input = torch.randn(30, 20, 10) # [30, 20, 10]
linear = torch.nn.Linear(10, 15) # (*, 10) --> (*, 15)
output = linear(input)
print(output.size()) # 输出 [30, 20, 15]