torch.sum()
函数详解
torch.sum()
是 PyTorch 中用于 对张量进行求和 的函数,可对所有元素求和,也可在指定维度上求和。
1. 基本语法
torch.sum(input, dim=None, keepdim=False, dtype=None)
参数 | 类型 | 说明 |
---|---|---|
input | Tensor | 要进行求和的张量 |
dim | int 或 tuple of int | 指定在哪个维度上求和(默认对所有元素求和) |
keepdim | bool | 是否保留被压缩的维度(默认 False) |
dtype | torch.dtype | 指定输出的数据类型(默认和输入相同) |
2. 示例:对所有元素求和
import torch
x = torch.tensor([[1, 2], [3, 4]])
print(torch.sum(x)) # 输出: tensor(10)
3. 示例:按维度求和
沿某一维求和(不保留维度)
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(torch.sum(x, dim=0)) # → tensor([5, 7, 9])(列方向)
print(torch.sum(x, dim=1)) # → tensor([6, 15])(行方向)
保留维度
print(torch.sum(x, dim=1, keepdim=True))
# 输出: tensor([[ 6],
# [15]])
4. 多维度求和
x = torch.arange(24).reshape(2, 3, 4)
print(x)
print(torch.sum(x, dim=(1, 2))) # 在第 1、2 维上求和
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
#
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
#tensor([ 66, 210])
5. 指定输出类型(dtype)
x = torch.tensor([1, 2, 3], dtype=torch.int32)
print(torch.sum(x, dtype=torch.float32)) # 输出类型为 float32
6. 常见用途
- 计算损失总和(如
loss.sum()
) - 用于归一化(softmax、平均损失)
- 用于 Mask 操作(如
masked.sum() / masked.count_nonzero()
) - 用于求总计、总和统计
7. 注意事项
- 使用
dim
时求和只发生在指定维度,其它维度不变。 - 如果你希望计算平均值,可以使用
torch.mean()
。 torch.sum()
不会自动进行梯度截断(用于backward()
时需注意)。
8. 总结
功能 | 用法 |
---|---|
所有元素求和 | torch.sum(x) |
指定维度求和 | torch.sum(x, dim=1) |
保留维度 | torch.sum(x, dim=1, keepdim=True) |
指定类型 | torch.sum(x, dtype=torch.float32) |
这是深度学习中常用的基本操作之一,用于损失计算、掩码处理、归一化等非常频繁。