torch.exp()
函数详解
torch.exp()
是 PyTorch 中用于对张量中每个元素计算自然指数函数
e
x
e^x
ex 的函数,常用于 实现 softmax、log-normalization、指数增长建模 等场景。
1. 函数原型
torch.exp(input, *, out=None) → Tensor
参数 | 说明 |
---|---|
input | 输入张量 |
out | 可选,输出张量,用于存放结果 |
2. 功能说明
对输入张量中的每个元素
x
x
x 执行:
exp
(
x
)
=
e
x
\text{exp}(x) = e^x
exp(x)=ex
其中
e
≈
2.71828
e \approx 2.71828
e≈2.71828。
3. 示例:基础用法
import torch
x = torch.tensor([0.0, 1.0, 2.0])
y = torch.exp(x)
print(y) # 输出:tensor([1.0000, 2.7183, 7.3891])
4. 示例:多维张量
x = torch.tensor([[0.0, -1.0], [1.0, -2.0]])
print(torch.exp(x))
输出:
tensor([[1.0000, 0.3679],
[2.7183, 0.1353]])
5. 常见应用
5.1 Softmax 的实现
x = torch.tensor([1.0, 2.0, 3.0])
softmax = torch.exp(x) / torch.sum(torch.exp(x))
print(softmax)
5.2 概率建模中的对数概率反变换
log_probs = torch.tensor([-1.0, -2.0])
probs = torch.exp(log_probs)
6. 注意事项
torch.exp()
作用于每个元素,支持任意维度的张量。- 输入为负数时,输出仍为正数(因为 e x > 0 e^x > 0 ex>0 对任意实数 x x x 成立)。
- 输入数值过大时可能导致数值溢出(输出为
inf
),因此在实际中会结合数值稳定性处理(如在 softmax 前减去最大值)。
7. 与 math.exp()
的区别
torch.exp() | math.exp() | |
---|---|---|
输入 | 张量(Tensor) | 单个标量 |
输出 | 张量 | 浮点数 |
支持自动求导 | 是 | 否 |
8. 总结
特性 | 说明 |
---|---|
作用 | 对张量每个元素计算 e x e^x ex |
输出 | 与输入形状相同的张量 |
应用 | softmax、概率模型、正态化、注意力机制等 |
注意 | 大数值可能导致溢出;常配合数值稳定技巧使用 |
这是 深度学习中最基础的数学函数之一,建议与 log()
、softmax()
等函数结合理解和使用。