【PyTorch】torch.sum() 函数: 对张量进行求和

torch.sum() 函数详解

torch.sum() 是 PyTorch 中用于 对张量进行求和 的函数,可对所有元素求和,也可在指定维度上求和。


1. 基本语法

torch.sum(input, dim=None, keepdim=False, dtype=None)
参数类型说明
inputTensor要进行求和的张量
diminttuple of int指定在哪个维度上求和(默认对所有元素求和)
keepdimbool是否保留被压缩的维度(默认 False)
dtypetorch.dtype指定输出的数据类型(默认和输入相同)

2. 示例:对所有元素求和

import torch

x = torch.tensor([[1, 2], [3, 4]])
print(torch.sum(x))  # 输出: tensor(10)

3. 示例:按维度求和

沿某一维求和(不保留维度)

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(torch.sum(x, dim=0))  # → tensor([5, 7, 9])(列方向)
print(torch.sum(x, dim=1))  # → tensor([6, 15])(行方向)

保留维度

print(torch.sum(x, dim=1, keepdim=True))  
# 输出: tensor([[ 6],
#              [15]])

4. 多维度求和

x = torch.arange(24).reshape(2, 3, 4)
print(x)
print(torch.sum(x, dim=(1, 2)))  # 在第 1、2 维上求和

# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])
#tensor([ 66, 210])

5. 指定输出类型(dtype)

x = torch.tensor([1, 2, 3], dtype=torch.int32)
print(torch.sum(x, dtype=torch.float32))  # 输出类型为 float32

6. 常见用途

  • 计算损失总和(如 loss.sum()
  • 用于归一化(softmax、平均损失)
  • 用于 Mask 操作(如 masked.sum() / masked.count_nonzero()
  • 用于求总计、总和统计

7. 注意事项

  • 使用 dim 时求和只发生在指定维度,其它维度不变。
  • 如果你希望计算平均值,可以使用 torch.mean()
  • torch.sum() 不会自动进行梯度截断(用于 backward() 时需注意)。

8. 总结

功能用法
所有元素求和torch.sum(x)
指定维度求和torch.sum(x, dim=1)
保留维度torch.sum(x, dim=1, keepdim=True)
指定类型torch.sum(x, dtype=torch.float32)

这是深度学习中常用的基本操作之一,用于损失计算、掩码处理、归一化等非常频繁

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值