pytorch标签onehot编码_3分钟理解 pytorch 的 gather 和 scatter

pytorch 的 gather 和 scatter 这两个方法,理解起来是有一定难度的,正因如此,网上有很多博客介绍它们,但往往作者讲得并不清晰。导致读者即使当时看懂了,过段时间又会忘记。经过第N遍谷歌这两个方法的用法,本人终于想到一种类比,可以方便地理解并牢固地记住它们。

类比

假设某校高三有四个班级,各班花名册可以总结为下表:

05e3e1de26fa860aebfe76b19148652f.png
表一 (为了方便理解,假设每班只有两人)

现在他们要报高考志愿,假设总共有五个备选高校:清华、北大、复旦、上交、浙大,序号依次为 0-4。各班的报名状况为:

ffb7ac15d4ed7f0ff2500585390e7767.png
表二

如果我想知道有哪些人报了清华大学,他们分别在哪个班,上面的表格是不够直观的。如果生成一个把大学作为行、把班级作为列的表格,则可以清晰地看到各个大学在各班的报名情况:

72ec1f0a71b09db426260174a1f86853.png
表三

读到这里,恭喜您,已经完成了一次 pytorch 中的 scatter 操作!

scatter

首先摘录 pytorch 的官方文档:

scatter_(dim, index, src, reduce=None) → Tensor
For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

不要被它的例子里面复杂的下标索引吓到。我们先从简单的二维情况分析,把上面的例子第一行改成二维的:

self[index[i][j]][j] = src[i][j] # if dim == 0

然后,类比报志愿的问题,这里 src 是表一,index 是表二,而 self 则是表三。i 是学生在该班级内部的编号(所在行),j 是班级编号,注意 i 和 j 都从 0 开始。

那么,src[i][j] 表示的是 j 班的第 i 个同学。根据表一,src[0][0] 就是小红,src[1][2] 是小强。

index[i][j] 表示的是 j 班 i 号同学所报大学的编号。因此,显而易见的是 index 的形状应该和 src 完全一致(严格来讲,其实只需 index 各维度尺寸小于等于 src 各维度对应尺寸,index 在 dim 之外的维度的尺寸,小于等于 self 在dim 之外的尺寸),同时,index 表中最大的数,不应超出可选大学的最大下标,也就是 self 的末行下标。

gather 所做的操作就是,先查出 index[i][j]:找到第 j 班 i 同学所报的大学的编号,该编号就是需要修改的表三的行下标;由于该同学所在的班级为 j 班,因此需要修改的表三的列下标为 j;确定了需要修改的表三的行、列下标,接下来需要确定的是修改的值,在我们的类比中,src[i][j] 是该同学的名字,因此把名字填入 self 中对应位置即可。

下面的代码演示了上面的类比过程:

import torch

# 生成表一,存的是各班学生姓名,用序号代替,序号按列(班级)依次递增
src = torch.arange(8).view(4, 2).transpose(0, 1).type(torch.float32)

print('src=')
print(src)

# tensor([[0, 2, 4, 6],
#         [1, 3, 5, 7]])

# 生成表二,存的是各班各同学所报大学的下标
index = torch.LongTensor([[0, 1, 4, 3],[2, 3, 0, 1]])
print('index=')
print(index)

# tensor([[0, 1, 4, 3],
#         [2, 3, 0, 1]])

#

# 初始化表三,即各校报名情况
tgt = torch.ones(5, 4)*-1
print('tgt=')
print(tgt)

# tensor([[-1., -1., -1., -1.],
#         [-1., -1., -1., -1.],
#         [-1., -1., -1., -1.],
#         [-1., -1., -1., -1.],
#         [-1., -1., -1., -1.]])

# dim=0 表示 index 存储的数字为 tgt 的行号,src 存的是学生名字
# scatter 将学生名字填入 tgt 对应的位置
tgt.scatter_(dim=0, index=index, src=src)
print('tgt=')
print(tgt)

# tensor([[ 0., -1.,  5., -1.],
#         [-1.,  2., -1.,  7.],
#         [ 1., -1., -1., -1.],
#         [-1.,  3., -1.,  6.],
#         [-1., -1.,  4., -1.]])

如果把代码中的 dim=0 改成 dim=1,scatter 的赋值规则变为:

self[i][index[i][j]] = src[i][j] # if dim == 1

回到上面的类比,index 存储的是 self 的列号,即大学的序号,i 是班级序号,j 是学生在该班的序号。最后得到的 self 为表三的转置。

至于 3D tensor,则可以扩展我们的类比,增加一个高中学校的编号:

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0

index[i][j][k] :k 号高中,j 班,i 号学生所报高校编号

src[i][j][k]:k 号高中,j 班,i号学生的名字

self[index[i][j][k]][j][k]:index[i][j][k] 号高校所招的 k 号 高中,j 班学生的名字

小结

scatter_(dim, index, src, reduce=None) → Tensor 中,index 存储的是 self 中 dim 维的下标。scatter 所做的事情是,把 src 中 index 所覆盖的元素(想象把 index 表盖到 src 上),填充到 self 中。填充位置的下标如何确定呢?其中 dim 维下标由该元素对应的 index 确定,而其余维度下标保持和 index 中一致。

gather

摘录 pytorch 的文档:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

上面的类比还可以继续用。这里的 input 为表三,index 为表二,out为表一。即:scatter 是由表一和表二得到表三,而gather则是由表三和表二得到表一。gather 是 scatter 的逆运算。

gather 所做的事就是:知道了各班各同学报的学校下标 (index),也知道各学校在各班的招生情况了(input),反推哪个班有哪个同学(out)。

我们继续上一小节的代码:

input = torch.gather(tgt, 0, index)

print('input=')
print(input)

# tensor([[0., 2., 4., 6.],
#         [1., 3., 5., 7.]])

果然得到了 scatter 中的 src。

scatter 的目标是修改现有矩阵,gather 是根据 index 从现有矩阵提取元素,构成一个新矩阵。新矩阵的形状与 index 一致,index 存储的元素决定了所提取的元素在 input 的 dim-维上的下标,dim-维之外的维度下标与 index 所存元素对应的下标一致。

总结

本文通过类比的方式讲了 scatter 和 gather 的用法,希望能够对读者有所帮助。对于里面可能存在的错漏的地方,望大神在评论区不吝赐教。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值