pytorch-scatter解释
举个二维的例子,用Python重写一下就很明白了,非常简单。注意:scatter返回新张量,scatter_在原张量上修改。torch中的默认都是这样,例如torch.tensor.data.zero_()就是直接在原张量上操作。
调用官方接口
src = torch.arange(1,7).reshape(2,-1).type(torch.FloatTensor)
target = torch.zeros((2, 3)).type(torch.FloatTensor)
index = torch.tensor([[1, 2, 0], [1, 2, 1]])
target_官方 = target.scatter_(-1, index, src)
'''
tensor([[3., 1., 2.],
[0., 6., 5.]])
'''
重写:
src = torch.arange(1,7).reshape(2,-1).type(torch.FloatTensor)
target = torch.zeros((2, 3)).type(torch.FloatTensor)
index = torch.tensor([[1, 2, 0], [1, 2, 1]])
for i in range(2):
for j in range(3):
try:# 如果i, j超出index范围,那么不做改变。
c = index[i][j]
# 如果c可以找到,那么c一定要在target索引范围内,并且i,j也一定要在src范围内,这要是为啥src要比target大的原因。
target[i][c] = src[i][j]
'''
tensor([[3., 1., 2.],
[0., 6., 5.]])
本质就是dim=1, i[index[i][j]] for all j
'''
用这个特性,将labels改成one-hot编码
# 5个样本,6分类:0-5,
labels = torch.tensor([0, 2, 5, 4, 1])
src = torch.ones((5, 6)) 或者 src = 1也可以
target = torch.zeros((5,6))
# 将lablels作为index使用,因此需要将labels改为二维,具体因为啥逻辑实际写一遍重写部分, 把(i,c)配对打出来。
labels = labels.unsqueeze(-1)
one_hot = target.scatter(-1, labels, src) # 或者 one_hot = target.scatter(-1, labels, 1)
'''
tensor([[1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 0., 0.]])
'''