对 torch 中 dim 的总结和理解

pytorch 中,使用到 dim 参数的 api 都是跟集合有关的,比如 max(), min(), mean(), softmax() 等。当指定某个 dim 时,表示使用该维度的所有元素进行集合运算,一个 tensor 的 shape 为 (3, 4, 5),分别对应的 dim 如下所示

dimshape
03
14
25

当使用 max(dim=1) 时,表示使用第二个维度中全部四个元素中的每个元素参与求最大值计算,计算后的 shape 变为 (3,5),因为只从 四个中求得最大的那个作为结果。如果 shape 的长度为 3,则 dim 的取值只能在区间 [-3, 2],否则将报错。

Example

>>> a = torch.randn(3,4,5)
# 求得第二个维度的最大值
>>> torch.max(a,1)
torch.return_types.max(
values=tensor([[0.7700, 0.1390, 0.6952, 1.9428, 0.8477],
        [1.0085, 0.7961, 0.9462, 2.1287, 0.9356],
        [1.1520, 2.1478, 0.8291, 1.0854, 0.7780]]),
indices=tensor([[1, 1, 2, 2, 0],
        [1, 2, 2, 3, 0],
        [0, 1, 3, 3, 3]]))
        
# 第二个维度缩减为只有一个元素,即 (3,1,5),api 将维度为 1 的去掉了
>>> torch.max(a,1).values.shape
torch.Size([3, 5])

# 第三个维度缩减为只有一个元素,即 (3,4,1),api 将维度为 1 的去掉了
>>> torch.max(a,2).values.shape
torch.Size([3, 4])

# 超出 dim 范围,报错
>>> torch.max(a,3).values.shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

总结:
1、dim 是一种集合运算的参数,表示将某个维度的所有元素参与集合运算
2、dim 的取值和 shape 的长度密切相关,dim 的取值为 [-len(shape), len(shape)-1]

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值