【Torch API】pytorch 中index_add()函数详解

import torch

# create a tensor of zeros with shape (3, 4)
t = torch.zeros((3, 4))

# create an index tensor of shape (2,)
index = torch.tensor([0, 2])

# create a tensor to add with shape (2, 4)
src = torch.ones((2, 4))

# add the src tensor to t at the indices specified in the index tensor
t.index_add(0, index, src)

print(t)

输出结果为:

tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.]])

解释:
上述代码,使用了 index_add() 方法将 src 张量添加到 t 张量中。具体而言,下面这行代码:

t.index_add(0, index, src)

完成了相加的操作。下面是每个参数的含义:

0:指定执行索引操作的维度。在本例中,我们要将 src 张量添加到由 index 张量指定的 t 张量的行上,因此设置 dim=0。

index:指定要将 src 张量添加到哪些位置的索引。在本例中,我们要将 src 张量添加到 t 的第 0 行和第 2 行,因此设置 index=torch.tensor([0, 2])。

src:要添加到 t 的张量。在本例中,src 是一个全部为 1 的张量,形状为 (2, 4)。

index_add() 方法的工作原理是将 src 的每一行加到由 index 张量指定的 t 张量的相应行上。在本例中,指定了 index=torch.tensor([0, 2]),因此 src 的第一行(全为 1)会被加到 t 的第一行,第二行(同样全为 1)会被加到 t 的第三行。结果是一个张量,其中第一行和第三行为 1,其它地方为 0。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值