PyTorch scatter_()和gather()解析

torch.Tensor.scatter_(dim, index, src)

scatter_(dim, index, src) → Tensor

官方解释:
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

不用看了,看了也不懂。。。

scatter_()操作可以这么理解:

假设x是一个tensorx.scatter_(dim, index, src)表示对x中的元素按某种规则进行修改(把某些来自参数src的元素放置进来替换掉原来位置的元素,没被替换的则保持原样), 而修改的依据则由参数dim和参数index决定。

scatter_()的三个输入参数:

  • src – 需要放置的元素就在这里面选取,可以是张量或标量
  • dim – 沿着哪个维度进行选取(index操作)
  • index – 需要选取的元素对应的索引

还是略抽象,看个简单的例子吧:

x = torch.zeros(3,3)
>>> x
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
src = torch.ones(2,3)
>>> src
tensor([[1., 1., 1.],
        [1., 1., 1.]])
        
index = torch.tensor([[0,0,1],[1,0,0]])
>>> index
tensor([[0, 0, 1],
        [1, 0, 0]])

dim先取0,操作结果:

x = x.scatter_(0, index, src)	# dim=0
>>> x
tensor([[1., 1., 1.],
        [1., 0., 1.],
        [0., 0., 0.]])

以上操作可以简化为:

torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), torch.ones(2,3))

可以看到,scatter_()操作将src中的某些元素 1 放置到了x中原本是 0 的位置,而放置的依据是dimindex两个参数,具体分析如下:

  • x是二维的,可以用x[i][j]来表示x中的相应位置的元素,比如:
x = torch.zeros(3,3)
>>> x[0][0]
tensor(0.)
  • dim = 0代表着沿着第0个维度进行操作,也就是对x[i][j]中的i进行相应的操作,替换规则如下:
    x [ i n d e x [ i ] [ j ] ] [ j ] = s r c [ i ] [ j ] (1) x[index[i][j]][j] = src[i][j] \tag{1} x[index[i][j]][j]=src[i][j](1)

    其中i为0、1,j为 0、1、2(这儿i,j的取值不要超过索引范围就ok啦)。

    i=0, j=0,1,2
    x[index[0][0]][0]index[0][0] = 0x[0][0] = src[0][0] = 1
    x[index[0][1]][1]index[0][1] = 0x[0][1] = src[0][1] = 1
    x[index[0][2]][2]index[0][2] = 1x[1][2] = src[0][2] = 1

    i=1, j=0,1,2
    x[index[1][0]][0]index[1][0] = 1x[1][0] = src[1][0] = 1
    x[index[1][1]][1]index[1][1] = 0x[0][1] = src[1][1] = 1
    x[index[1][2]][2]index[1][2] = 0x[0][2] = src[1][2] = 1

    所以上述操作就是将x中的某些元素修改为src中对应位置的元素,而对应关系是由dimindex来决定 的。注意上述操作中,i=0的第二项和i=1的第二项实际上是都是对x[0][1]进行修改,只不过修改之后的值恰 好是相同的。

  • 那么dim = 1就代表着沿着第1个维度进行操作,也就是对x[i][j]中的j进行相应的操作,规则如下:
    x [ i ] [ i n d e x [ i ] [ j ] ] = s r c [ i ] [ j ] (2) x[i][index[i][j]] = src[i][j] \tag{2} x[i][index[i][j]]=src[i][j](2)
    具体就不展开细说了。

下面还是来看一下官方例子吧:

src = torch.rand(2, 5)
>>> src
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
x = torch.zeros(3, 5)

x.scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src)

tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

随便测试一下:

x[index[0][0]][0]index[0][0] = 0x[0][0] = src[0][0] = 0.3992
x[index[0][2]][2]index[0][2] = 2x[2][2] = src[0][2] = 0.9044
是不是都对应上啦 (✿◡‿◡)

除了二维的,三维的也是同样的道理,官网也给出的三维的对照规则:

if dim = 0, s e l f [ i n d e x [ i ] [ j ] [ k ] ] [ j ] [ k ] = s r c [ i ] [ j ] [ k ] self[index[i][j][k]][j][k] = src[i][j][k] self[index[i][j][k]][j][k]=src[i][j][k]
if dim = 1, s e l f [ i ] [ i n d e x [ i ] [ j ] [ k ] ] [ k ] = s r c [ i ] [ j ] [ k ] self[i][index[i][j][k]][k] = src[i][j][k] self[i][index[i][j][k]][k]=src[i][j][k]
if dim = 2, s e l f [ i ] [ j ] [ i n d e x [ i ] [ j ] [ k ] ] = s r c [ i ] [ j ] [ k ] self[i][j][index[i][j][k]] = src[i][j][k] self[i][j][index[i][j][k]]=src[i][j][k]

好了,前面我们说到了src除了是张量之外,还可以是标量,比如之前的例子:

torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), torch.ones(2,3))

修改为:

torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), 1)

也是同样的效果。
scatter_()操作常常用来对标签进行 one-hot 编码,这也是我最开始碰到这个函数的地方。one-hot 编码使用的src一般就为标量,也就是利用一个标量对张量进行修改。

我们在计算 loss(比如Cross entropy loss)的时候,预测矩阵的形状一般为 (batch_size, num_classes),而标签的形状一般为 (batch_size, 1),需要将标签形状修改为(batch_size, num_classes)以方便计算,其实就是将标量扩展成张量。一般这么进行:

num_classes = 5
batch_size = 4

label = torch.randint(0, num_classes, size=(batch_size,1))
>>> label
tensor([[0],
        [1],
        [4],
        [3]])

原始4个样本的标签分别是0、1、4、3,代表样本分别属于第一、二、五、四个类别,此处回顾一下交叉熵的计算方式:

l o s s = − ∑ i y i ln ⁡ y ^ i (3) loss=-\sum_{i} y_{i} \ln \hat{y}_{i} \tag{3} loss=iyilny^i(3)

次数的 y i y_{i} yi y ^ i \hat{y}_{i} y^i 具有相同的形状。比如上述的label中第一个样本的标签是0,把它转为长为num_classes张量即为:

# 第一个类别处为1, 其余全为0
tensor([1, 0, 0, 0, 0])

one-hot 编码其实就是这样的操作,所属标签类别处为1,其余为0,用scatter_()函数实现如下:

# dim = 1
one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
>>> one_hot
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.]])

注意到编码后的one-hot张量对应每个样本的标签处都实现了相应的转换。

其实在PyTorch 1.1版本之后,提供了更简单的 one-hot 编码方式:
https://stackoverflow.com/questions/44461772/creating-one-hot-vector-from-indices-given-as-a-tensor

one_hot = torch.nn.functional.one_hot(label, 5)
>>> one_hot
tensor([[[1, 0, 0, 0, 0]],

        [[0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 1]],

        [[0, 0, 0, 1, 0]]])

不过要注意生成的张量形状略有不同(此处为4x1x5,之前为4x5)。如果需要变换为4x5:

one_hot.squeeze(1)

即可!!!

torch.gather(input, dim, index, out=None, sparse_grad=False)

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

理解了上面的scatter_(),再来理解gather()就很简单了,因为他们两个互为逆向操作,对于torch.gather(),三维操作规则如下:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

是不是看起来就很熟悉了,只是satter_()操作的三维规则逆向过来。同理二维的操作规则:

out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1

例子:

input = torch.tensor([[1,2],[3,4]])
output = torch.gather(input, 0, torch.tensor([[0,0],[1,0]]))
>>> output
tensor([[1, 2],
        [3, 2]])

取两个测试一下:

output[0][0] = input[index[0][0]][0] = input[0][0] = 1
output[1][1] = input[index[1][1]][1] = input[0][1] = 2

完。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值