PyTorch中scatter()和scatter_()函数的作用

本文讲的是我对PyTorch中scatter()函数的理解。

原创,转载请标明来源。

 

一言以蔽之:修改tensor中的指定位置的值。

函数

scatter(dim, index, src) 

  • dim: 索引的维度。按照i, j, k, ...的哪个方向进行索引
  • index: 索引。可以是一个tensor,存储需要改的元素的位置的tensor
  • src: 用src中的值来修改。可以是tensor;可以是一个数字,用同样的数字写入tensor

scatter() 和 scatter_() 函数功能相同:只不过带下划线的函数,通常是直接修改原来的tensor

原理

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

函数的具体实现,如上述代码框所示:使用src中的值,修改self中位置为index[i][j][k]的值。

举例

# 这是src
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

# index是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
# self就是下面的torch.zeros(3, 5)
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

# 这是结果
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

在这个例子中,dim=0,是按照i的方向修改torch.zeros(3, 5)的。可以看出,index实际上表达了一种映射关系:同样都是在第j列上,将src在该列的值 根据index在该列的指示 映射到self的这一列上。

比如:src在第0列的0.01940和0.2078,被放到self的第0列上,但不是完全一样的放过来,而是经过index变化了上下位置。其他列同理。

在简单RNN中的应用

该应用代码如下 [1]:

def one_hot(x, n_class, dtype=torch.float32): 
    # X shape: (batch), output shape: (batch, n_class)
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

x = torch.tensor([0, 2])
one_hot(x, vocab_size)

 运行结果:

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])

在本例中,该函数的任务是将输入的文本使用one_hot编码。其中,x是一个vector,代表一个二字词语,其中的0和2代表汉字(在程序上文定义的字典中)所对应的数字。vocab_size是字典大小,即在该程序中所考虑的汉字总个数。n_class是one_hot编码中所考虑的类别数,在本例中等于vocab_size。

res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) :生成了两行,vocab_size列的零矩阵

res.scatter_(1, x.view(-1, 1), 1) :在res中,将1,按照dim=1(即不改行改列)的方向,根据[[0],[2]]所指示的位置,放入res中。(比如,x中的0,代表要放入第0列;而0本身处于第0行,所以是第0行中的第0列。)

 

参考文献

[1] 循环神经网络的从零开始实现. 原书作者:阿斯顿·张、李沐、扎卡里 C. 立顿、亚历山大 J. 斯莫拉以及其他社区贡献者. 原书名称:动手学深度学习Pytorch版

[2] PyTorch官方文档

  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值