举个例子
import torch
# create a tensor of zeros with shape (3, 4)
t = torch.zeros((3, 4))
# create an index tensor of shape (2,)
index = torch.tensor([0, 2])
# create a tensor to add with shape (2, 4)
src = torch.ones((2, 4))
# add the src tensor to t at the indices specified in the index tensor
t.index_add(0, index, src)
print(t)
输出结果为:
tensor([[1., 1., 1., 1.],
[0., 0., 0., 0.],
[1., 1., 1., 1.]])
解释:
上述代码,使用了 index_add() 方法将 src 张量添加到 t 张量中。具体而言,下面这行代码:
t.index_add(0, index, src)
完成了相加的操作。下面是每个参数的含义:
0:指定执行索引操作的维度。在本例中,我们要将 src 张量添加到由 index 张量指定的 t 张量的行上,因此设置 dim=0。
index:指定要将 src 张量添加到哪些位置的索引。在本例中,我们要将 src 张量添加到 t 的第 0 行和第 2 行,因此设置 index=torch.tensor([0, 2])。
src:要添加到 t 的张量。在本例中,src 是一个全部为 1 的张量,形状为 (2, 4)。
index_add() 方法的工作原理是将 src 的每一行加到由 index 张量指定的 t 张量的相应行上。在本例中,指定了 index=torch.tensor([0, 2]),因此 src 的第一行(全为 1)会被加到 t 的第一行,第二行(同样全为 1)会被加到 t 的第三行。结果是一个张量,其中第一行和第三行为 1,其它地方为 0。