使用scatter_add_
更新tensor张量中指定index位置的值
例子
import torch
a = torch.zeros((3, 4))
print(a)
"""
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
"""
b = torch.rand((2, 4))
print(b)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.5577],
[0.3469, 0.1025, 0.8185, 0.5085]])
"""
# 将a中第0行和第2行的值修改为b
a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b)
print(a)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.3469, 0.1025, 0.8185, 0.0000]])
"""