一、Scatter函数的理解
这个函数是用一个src的源张量或者标量以及索引来修改另一个张量,常用来做one-hot编码。这个函数主要有三个参数 scatter(dim,index,src)
- dim:沿着哪个维度来进行索引(一会儿举个例子就明白了)
- index:用来进行索引的张量
- src:源张量或者标量
# 这里是三位矩阵的情况下
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
# 在常见的二维矩阵的情况下
#dim=0
self[index[x][y]][y]=src[x][y] # 列相同
#dim=1
self[x][index[x][y]]=src[x][y] # 行相同
二、Scatter函数进行独热编码(one-hot)
import torch
index = torch.arange(5).unsqueeze(1)
'''
index =
tensor([[0],
[1],
[2],
[3],
[4]])
'''
# 这里的scatter加上_是指在原先的torch.zeros(5,5)直接修改,不创建新的副本
one_hot = torch.zeros(5,5).scatter_(1, index, 1)
'''
one_hot =
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
'''
具体的可以参考这篇博客:PyTorch笔记之 scatter() 函数