torch.nn.Linear
(全连接层)
torch.nn.Linear
是 PyTorch 中用于实现全连接层(Fully Connected Layer,FC 层) 的类,通常用于 MLP(多层感知机)、CNN 的分类层、Transformer 等任务。
1. torch.nn.Linear
语法
torch.nn.Linear(in_features, out_features, bias=True)
参数 | 说明 |
---|---|
in_features | 输入特征的维度 |
out_features | 输出特征的维度 |
bias | 是否包含偏置项(默认为 True ) |
- 计算公式:
y = x W T + b y = xW^T + b y=xWT+b
其中:-
x
x
x 是输入张量,形状
(..., in_features)
-
W
W
W 是权重矩阵,形状
(out_features, in_features)
-
b
b
b 是可选的偏置项,形状
(out_features)
-
x
x
x 是输入张量,形状
2. 示例:定义 Linear
层
import torch
import torch.nn as nn
# 定义全连接层:输入 4 维,输出 3 维
linear = nn.Linear(in_features=4, out_features=3)
# 查看权重和偏置
print("Weight shape:", linear.weight.shape) # torch.Size([3, 4])
print("Bias shape:", linear.bias.shape) # torch.Size([3])
解析
in_features=4
,out_features=3
,即 输入 4 维,输出 3 维。linear.weight.shape
为(3, 4)
,linear.bias.shape
为(3,)
。
3. Linear
层前向计算
# 创建输入张量(batch_size=2, features=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
# 计算输出
output = linear(x)
print(output.shape) # torch.Size([2, 3])
print(output)
解析
- 输入
x
形状(2, 4)
,表示 batch_size=2,每个样本有 4 个特征。 - 经过
Linear(4, 3)
层后,输出变为(2, 3)
。
4. Linear
在神经网络中的应用
4.1 多层感知机(MLP)
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.fc1 = nn.Linear(10, 32) # 输入 10 维,隐藏层 32 维
self.fc2 = nn.Linear(32, 5) # 隐藏层 32 维,输出 5 维
def forward(self, x):
x = torch.relu(self.fc1(x)) # 第一层 ReLU
x = self.fc2(x) # 第二层输出
return x
# 创建模型
model = MLP()
# 生成输入数据(batch_size=3, features=10)
x = torch.randn(3, 10)
# 前向传播
output = model(x)
print(output.shape) # torch.Size([3, 5])
解析
fc1 = Linear(10, 32)
:输入 10 维,映射到 32 维。fc2 = Linear(32, 5)
:映射到 5 维(例如 5 类分类任务)。ReLU
作为非线性激活函数。
4.2 CNN 结合 Linear
进行分类
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3) # 3 通道输入
self.fc1 = nn.Linear(16 * 30 * 30, 10) # 展平后输入 FC 层
def forward(self, x):
x = torch.relu(self.conv1(x)) # 经过卷积
x = x.view(x.size(0), -1) # 展平
x = self.fc1(x) # 通过全连接层
return x
model = CNN()
解析
Conv2d
提取图像特征,Linear
进行分类。
5. bias=False
的用法
linear_no_bias = nn.Linear(4, 3, bias=False)
bias=False
去掉偏置项,仅使用权重W
进行线性变换。- 适用于 Batch Normalization 或某些特殊任务。
6. 权重初始化
Linear
层的权重默认采用 Kaiming 均匀初始化,但可以手动修改:
nn.init.xavier_uniform_(linear.weight) # Xavier 初始化
nn.init.constant_(linear.bias, 0) # 偏置初始化为 0
7. nn.Linear
vs F.linear
方式 | 调用方式 | 是否可训练 |
---|---|---|
nn.Linear | 作为网络层定义 | 是(可训练) |
F.linear | 直接计算 y = xW^T + b | 否(仅计算,不创建参数) |
示例
import torch.nn.functional as F
# 直接计算线性变换(不创建参数)
x = torch.randn(2, 4)
weight = torch.randn(3, 4)
bias = torch.randn(3)
output = F.linear(x, weight, bias) # 计算 y = xW^T + b
nn.Linear
创建可训练参数,适用于nn.Module
。F.linear
仅用于临时计算,不用于模型训练。
8. 适用场景
- MLP(全连接神经网络)
- CNN 分类层
- Transformer(注意力机制)
- 自动编码器(AutoEncoder)
9. 结论
torch.nn.Linear(in_features, out_features, bias=True)
用于线性变换。- 适用于神经网络分类任务、特征映射等。
- 可以与
ReLU
、BatchNorm
等结合。 - 推荐
nn.Linear
用于模型定义,F.linear
用于计算测试。