pytorch中torch.Tensor.scatter用法

今天读代码,看到有人用了torch.Tensor.scatter这个函数。这个函数之前我也看到过,但是没有搞明白是干啥用的,今天我搞明白了。

首先看一下官方文档的定义:(可跳过)

==============================pytorch docs分割线====================================

scatter_(dim, index, src) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

This is the reverse operation of the manner described in gather().

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

Parameters

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity

  • src (Tensor) – the source element(s) to scatter, incase value is not specified

  • value (float) – the source element(s) to scatter, incase src is not specified

  • ===============================================================================

通过自己的实验,我明白了这个函数的目的,下面讲解一下:

首先看一下这个函数的接口,需要三个输入:1)维度dim 2)索引数组index 3)原数组src,为了方便理解,我们后文把src换成input表示。最终的输出是新的output数组。

下面依次介绍:

1)维度dim:整数,可以是0,1,2,3...

2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置

3)原数组input:也是一个tensor,其中的数据类型任意

先说一下这个函数是干嘛的,在我看来,这个scatter函数就是把input数组中的数据进行重新分配。index中表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充0。

比如说下面这段代码:

import torch

input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

 

运行结果如下:

tensor([[-0.2558, -1.8930, -0.7831,  0.6100],
        [ 0.3246,  2.1289,  0.5887,  1.5588]])
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],
        [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])

下面,我详细说一下为什么会是这样的结果。

前面说了,scatter是input数组,根据index数组,对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。

input:

tensor([[-0.2558, -1.8930, -0.78310.6100],
             [ 0.3246,  2.1289,  0.5887,  1.5588]])

index:

index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output:

tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],
             [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])

首先,对input[0][0]进行重分配。符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。

input[0][0] -> output[0][index[0][0]] = output[0][3]

数据位置发生的变化都是在第1维上,第0维不变。

input[0][1] -> output[0][index[0][1]] = output[0][1]

input[0][2] -> output[0][index[0][2]] = output[0][2]

input[0][3] -> output[0][index[0][3]] = output[0][0]

需要注意的是,

为了方便理解,我们是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0。

一般scatter用于生成onehot向量,如下所示:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

输出结果是:

tensor([[0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [1., 0., 0., 0.],
             [0., 0., 0., 1.]])

如果input是一个数字的话,代表这用于分配到output的数字是多少。

 

 

 

  • 60
    点赞
  • 61
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值