参考链接PyTorch的nn.Linear()详解 - douzujun - 博客园 (cnblogs.com)
这里演示了二维张量的全连接 :
其实还可以输入三维张量,演示如下:
from torch import nn
import torch
# in_features由输入张量的形状决定,out_features则决定了输出张量的形状
linear = nn.Linear(in_features=64 * 3, out_features=5)
# 10个 大小为7*64*3, 3个channel 的张量
a = torch.rand(10, 3, 7, 64 * 3)
print(a.shape) # torch.Size([10, 3, 7, 192])
print(linear.weight.shape) # torch.Size([5, 192])
b = linear(a)
print(b.shape) # torch.Size([10, 3, 7, 5])