torch.nn.functional.log_softmax
torch.nn.functional.log_softmax
是 PyTorch 提供的用于计算 log(softmax)
的函数,通常用于 多分类任务 和 计算交叉熵损失,可以提高数值稳定性并防止数值溢出。
1. log_softmax
的数学公式
对于 输入张量
X
X
X,softmax 计算如下:
softmax
(
X
i
)
=
e
X
i
∑
j
e
X
j
\text{softmax}(X_i) = \frac{e^{X_i}}{\sum_{j} e^{X_j}}
softmax(Xi)=∑jeXjeXi
然后取对数:
log
softmax
(
X
i
)
=
X
i
−
log
∑
j
e
X
j
\log \text{softmax}(X_i) = X_i - \log \sum_{j} e^{X_j}
logsoftmax(Xi)=Xi−logj∑eXj
为什么使用 log_softmax
而不是 softmax + log
?
- 防止数值溢出:
softmax(X)
可能会导致指数运算 溢出(特别是X
取值较大时)。log_softmax(X)
计算时,先 进行数值归一化,不会导致溢出。
- 更高效:
log_softmax
计算速度更快,因为它可以与nll_loss
直接配合使用。
2. torch.nn.functional.log_softmax
语法
torch.nn.functional.log_softmax(input, dim)
参数 | 说明 |
---|---|
input | 输入张量(通常是模型 logits 输出) |
dim | 计算 softmax 的维度(通常 dim=1 ,沿类别维度计算) |
3. 示例:计算 log_softmax
import torch
import torch.nn.functional as F
# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])
# 计算 log_softmax(dim=1 表示沿类别维度)
log_probs = F.log_softmax(logits, dim=1)
print(log_probs)
输出
tensor([[-0.4170, -1.4170, -2.3170],
[-2.0076, -0.5076, -1.0076]])
解析
logits
未归一化,直接输入。log_softmax(logits, dim=1)
对每一行计算 log(softmax)。- 数值稳定性更高,避免
softmax
溢出问题。
4. log_softmax
与 softmax + log
的区别
logits = torch.tensor([[2.0, 1.0, 0.1]])
# 方式 1:直接使用 log_softmax
log_probs1 = F.log_softmax(logits, dim=1)
# 方式 2:先 softmax 再 log
softmax_probs = torch.softmax(logits, dim=1)
log_probs2 = torch.log(softmax_probs)
print(log_probs1)
print(log_probs2)
两者输出相同,但 log_softmax
更稳定!
tensor([[-0.4170, -1.4170, -2.3170]])
tensor([[-0.4170, -1.4170, -2.3170]])
F.log_softmax
比torch.softmax + torch.log
更稳定,计算更高效。
5. log_softmax
在交叉熵损失中的作用
PyTorch F.cross_entropy
内部已经包含 log_softmax
:
import torch.nn.functional as F
# 假设 batch_size=2,类别数=3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 1.5]])
targets = torch.tensor([0, 2]) # 真实类别索引
# 交叉熵损失(内部自动使用 log_softmax)
loss = F.cross_entropy(logits, targets)
print(loss)
解析
F.cross_entropy(logits, targets)
内部使用log_softmax + nll_loss
。- 直接传入
logits
,无需手动计算softmax
。
6. log_softmax
与 nll_loss
如果已经计算了 log_softmax
,可以直接使用 nll_loss
:
log_probs = F.log_softmax(logits, dim=1)
# 使用 nll_loss 计算损失
loss = F.nll_loss(log_probs, targets)
print(loss)
nll_loss
只计算 log_softmax
结果与目标的匹配度,与 F.cross_entropy
结果相同。
7. 适用场景
- 分类任务(如 CNN, NLP):
F.cross_entropy
内部已使用log_softmax
,无需手动计算。
- 强化学习(RL):
- 计算 策略梯度损失 时,经常需要
log_softmax
。
- 计算 策略梯度损失 时,经常需要
- 变分自编码器(VAE):
- 计算 ELBO 损失 需要
log_softmax
。
- 计算 ELBO 损失 需要
8. 结论
torch.nn.functional.log_softmax
用于计算log(softmax)
,提高数值稳定性。- 比
softmax + log
更稳定,更高效。 - 在 分类任务 中,
cross_entropy
已包含log_softmax
,无需额外计算。 - 在强化学习、VAE 等任务中,
log_softmax
也是常用的概率计算方法。