PyTorch 中 nn.Linear 层特性总结笔记

一、核心特性概述

在 PyTorch 中,nn.Linear(in_features, out_features) 线性层总是作用于输入张量的最后一个维度,其他维度被视为 “任意形状” 的 batch 维度。无论输入张量是几维,线性层都会提取最后一个维度进行处理。

  • 处理逻辑:对输入张量最后一个维度执行线性变换,输出张量除最后一个维度变为 out_features 外,其他维度保持不变 。

  • 维度要求:输入张量最后一个维度需等于 in_features ,输出张量最后一个维度为 out_features

二、示例说明

1. 三维输入(常见于 NLP 序列数据)

x = torch.randn(2, 4, 512)  # \[batch\_size, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape)  # 输出: \[2, 4, 512]

线性层仅对最后维度 512 进行变换,[2, 4] 作为批量相关维度保留不变。

2. 四维输入(如图像 Transformer 数据)

x = torch.randn(2, 8, 16, 512)  # \[batch, head, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape)  # 输出: \[2, 8, 16, 512]

同样,线性层仅处理最后维度 512 ,其余维度原样输出。

3. 二维输入(传统 NLP 词向量)

x = torch.randn(32, 512)  # \[batch\_size, embed\_dim]
output = linear(x)
print(output.shape)  # 输出: \[32, 512]

无论输入维度简单或复杂,线性层均基于最后维度执行变换。

三、数学原理

对于形状为 [..., in_features] 的输入张量 x ,经 nn.Linear(in_features, out_features) 变换后,输出张量形状变为 [..., out_features] 。其数学运算公式为:

y=x⋅WT+by = x \cdot W^T + by=xWT+b

其中:

  • W∈Rout_features×in_featuresW \in \mathbb{R}^{out\_features \times in\_features}WRout_features×in_features ,是可学习权重矩阵;

  • b∈Rout_featuresb \in \mathbb{R}^{out\_features}bRout_features ,是可学习偏置向量。

四、应用场景举例(以 Transformer 为例)

输入张量形状 最后一个维度 是否匹配 Linear(embed_dim, embed_dim)应用场景
[2, 4, 512]512✅ 是,可直接使用 常规 Transformer 层间变换
[2, 8, 4, 64]64✅ 是,适用于多头注意力中每个 head 的输出 多头注意力后处理
[10, 20]20✅ 可用于任何需要线性变换的地方 简单词向量变换

五、总结

nn.Linear(embed_dim, embed_dim) 仅关注输入张量最后一个维度是否等于 in_features ,并基于此进行线性变换,其他维度保持不变。该特性赋予了线性层在 RNN、Transformer、CNN 等多种网络结构中灵活应用的能力 。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值