import torch
'''
A.scatter_(dim, index, B) # 基本用法, tensor A 被就地scatter到 tensor B
源tensor的每个元素,都按照 index 被scatter(可以理解为填充)到目标tensor中。
B 为源tensor,A为目标tensor。
dim 和 index:两个参数是配套的;
index和源tensor维度一致(可以为空,代表不改变目标tensor),对于n-D tensor,dim可以为0~N-1。
index为几,就把对应位置的元素放入目标tensor的第几行;
reduce参数:
默认是None,直接覆盖
multiply: src元素 * target元素
add:src元素 + target元素
对于全0矩阵,None和add效果一致;对于全1矩阵,None和multiply效果一致。
'''
a = torch.randn(2, 3) # 源tensor
print(a)
b = torch.zeros(2, 3).scatter_(dim=1, index=torch.tensor([[1, 2], [0, 1]]), src=a)
print(b)
'''
上例结果:
tensor([[-0.5172, 0.0915, -1.9869],
[-0.1619, 1.3641, 0.1983]])
tensor([[ 0.0000, -0.5172, 0.0915],
[-0.1619, 1.3641, 0.0000]])
'''
c = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a)
print(c)
'''
上例结果:
a: tensor([[ 0.2210, -1.2891, 1.1144],
[-0.3524, 0.1736, 2.0364]])
c: tensor([[-0.3524, -1.2891, 0.0000],
[ 0.2210, 0.1736, 0.0000]])
'''
d = torch.ones(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="multiply")
# print(d)
'''
tensor([[-8.7126e-01, 1.3744e+00, -5.1777e-04],
[-1.6414e+00, 1.1157e+00, -1.9982e+00]])
tensor([[-1.6414, 1.3744, 1.0000],
[-0.8713, 1.1157, 1.0000]])
'''
e = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="add")
print(e)
'''
tensor([[-0.7597, 1.3491, -0.2875],
[ 1.5010, -1.6951, 2.6675]])
tensor([[ 1.5010, 1.3491, 0.0000],
[-0.7597, -1.6951, 0.0000]])
'''
参考:
https://zhuanlan.zhihu.com/p/339043454
本文详细介绍了PyTorch中的scatter_方法及其使用方式,包括如何通过指定维度和索引来将一个张量的值散布到另一个张量中,并解释了不同reduce参数的效果。
5224

被折叠的 条评论
为什么被折叠?



