'''
nn.Linear(in_features,out_features,bias=True)
用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
输入:[batch_size, in_features]
输出:[batch_size,output_size]
'''
import torch as t
from torch import nn
# in_features由输入张量的形状决定,out_features则决定了输出张量的形状
connected_layer = nn.Linear(in_features = 64*64*3, out_features = 1)
# 假定输入的图像形状为[64,64,3]
input = t.randn(1,64,64,3)
# 将四维张量转换为二维张量之后,才能作为全连接层的输入
input = input.view(1,64*64*3)
print(input.shape) # torch.Size([1, 12288])
output = connected_layer(input) # 调用全连接层
print(output.shape) # torch.Size([1, 1])
参考链接:
PyTorch的nn.Linear()详解