理解torch.scatter_()
官方文档
scatter_
(dim, index, src): 将src中所有的值分散到self
中,填法是按照index
中所指示的索引来填入。
dim
用来指定index
进行映射的维度,其他维度则保持不变。
Note: src
可以是一个scalar。在这种情况下,该函数的操作是根据index
来散布单个值。
当dim=0
dim=0,意味着在src
按照index
行索引的指示来进行散射,换言之,src
的j
列按照index
的j
列中的值散射到self
的j
列中。(表述还是很绕,看例子吧)
以下是官方的例子:
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
因为dim=0,所以是列映射到列,散射操作可以按列依次进行。
第一列:
第二列:
直到最后一列:
当dim = 1
dim=1,意味着在src
按照index
列索引的指示来进行散射,换言之,src
的i
行按照index
的i
行中的值散射到self
的i
列中。
>>> src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
>>> input_tensor = torch.zeros(3, 5)
>>> index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
>>> dim = 1
>>> input_tensor.scatter_(dim, index_tensor, src)
tensor([[ 2., 4., 3., 1., 5.],
[ 7., 10., 6., 9., 0.],
[ 0., 0., 0., 0., 0.]])
散射操作前:
更新第一行:
更新第二行, 可以看到index
中出现重复的映射索引值1
,因此后一个会把前一个覆盖:
8和10都是映射到col1
,可以看到10把8给覆盖了。
当src是scalar
>>> input_tensor = torch.from_numpy(np.arange(1, 16)).float().view(3, 5) # dim is 2
>>> # unsqueeze to have dim = 2
>>> index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1)
>>> src = 0
>>> dim = 1
>>> input_tensor.scatter_(dim, index_tensor, src)
tensor([[ 1., 2., 3., 4., 0.],
[ 0., 7., 8., 9., 10.],
[11., 0., 13., 14., 15.]])
Note:
-
index
的维度要和输入张量的维度保持一致。同时index
要在相同维度上的尺度不能大于输入张量。 -
当
src
是标量时,我们实际上使用的是广播版本,其形状与index
张量相同。
代码实操
该函数最常用的场景是把标量的标签转换为one-hot编码
batch_size = 4
class_num = 5
labels = torch.tensor([4, 0, 1, 2]).unsqueeze(1)
one_hot = torch.zeros(batch_size, class_num)
dim=1; index_tensor = labels; src=1
one_hot.scatter_(dim, index_tensor, src)
print(one_hot)
> tensor([[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
References:
-
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_