miguemath指出:torch.Tensor.scatter_()是torch.gather()函数的方向反向操作。两个函数可以看成一对兄弟函数。gather用来解码one hot,scatter_用来编码one hot。
接下来我们用例子看一下:
1. Tensor.scatter(dim, index, src) → Tensor
它是torch.Tensor.scatter_()的错位版本,即:
scatter_(dim, index, src, reduce=None) → Tensor
该函数用来scatter
对于一个3-D的tensor,self会被更新为
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
link:官方文档1
在自然语言处理当中,相应任务的完成离不开vocabulary,vocab将相应的词语token映射为id,我们假设有一组词语,他们的id号是[1,2,0,3], 我们可以用scatter函数得到他们的onehot编码。
>>> index = torch.tensor([1,2,0,3])
>>> index = index.unsqueeze(-1)
>>> index
tensor([[1],[2],[0],[3]])
>>> onehot = torch.zeros(4, 4)
>>> onehot
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> onehot.scatter_(1, index, 1)
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
这里该函数的执行逻辑应该是:
onehot[i][index[i][j]] = 1 for i in range(index.shape[0]) for j in range(index.shape[1])
这里第三个参数可以为1的原因是该函数除了src还有一个输入参数, 在src没有指定时使用(详见文档):
value (float) – the source element(s) to scatter, incase src is not specified
2. Tensor.gather(dim, index) → Tensor
该函数等价于:
torch.gather(input, dim, index, sparse_grad=False, out=None) → Tensor
这里的input就是我们的使用对象self,这两个等价函数的输出如下:
对于一个3-D的tensor,output为
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
link:官方文档2
我们可以利用gather函数将one hot编码中的具体数值取出来(一般很少是从one hot编码中取,而是从模型得到的、关于词表的一个概率向量中取出想要词汇的概率)
>>> prob_vector= torch.rand((3,5))
>>> prob_vector
tensor([[0.5145, 0.7593, 0.0271, 0.2807, 0.0975],
[0.8461, 0.1193, 0.7042, 0.3711, 0.7330],
[0.3504, 0.7746, 0.5122, 0.8491, 0.7501]])
>>> index = torch.tensor([2,1,3])
>>> index = index.unsqueeze(-1)
>>> index
tensor([[2],
[1],
[3]])
>>> target_prob = prob_vector.gather(1,index)
>>> target_prob
tensor([[0.0271],
[0.1193],
[0.8491]])
我们现在有一个概率向量prob_vector,我们想取出与ground truth词汇对应的那个概率,并计算交叉熵。比如第一时间步groud truth词汇的id为2,后面分别为1,3,这时,我们可以通过gather取出相应的概率,就可以很简单地计算交叉熵(将取出的概率取log()即可,因为ground truth词汇的one hot vector在对应位置概率就是1,其他位置全为零),理论上这比整个概率向量(比如[0.5145, 0.7593, 0.0271, 0.2807, 0.0975])直接去与ground truth的one hot编码(比如[0,0,0,1,0])去做交叉熵效率更高。
同样,如果不是在求交叉熵的场景,你也可以利用gather函数从原tensor中取出任何的元素。
3、参考资料
[1]https://pytorch.org/docs/stable/
[2]one hot编码:torch.Tensor.scatter_()
函数用法详解