搜出来的博客很多都写得不清楚,所以我在看了一会了后才发现了这个函数的规律。
其实很简单
scatter_(dim,index,src)
dim=0就是每列,dim=1就是每行
index就是要填充的位置
src可以是一个tensor也可以是一个标量
>>> x = torch.rand(2, 5)
>>> x
tensor([[0.3469, 0.8207, 0.6422, 0.4681, 0.2340],
[0.1284, 0.0996, 0.0661, 0.5112, 0.2919]])
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[0.3469, 0.0996, 0.0661, 0.4681, 0.2340],
[0.0000, 0.8207, 0.0000, 0.5112, 0.0000],
[0.1284, 0.0000, 0.6422, 0.0000, 0.2919]])
对于index来说,[0][0]=0 , [0][1]=2就是在torch.zeros(3,5)的第1列里的第0个和第2个位置填充x的值0.3469和0.1284。[2][0]=2 , [2][1]=0就是在torch.zeros(3,5)的第3列里的第2个和第0个位置填充x的值0.6422和0.0661。
>>> torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1)
tensor([[0., 0., 1., 0.],
[0., 0., 0., 1.]])
这里dim=1,就是在(2,4)的第一行的第2个位置和第二行的第三个位置填充1