【Pytorch】scatter函数详解

在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:

target.scatter(dim, index, src)

其中各变量及参数的说明如下:

  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐列进行行填充
  • index: 按照轴方向,在target张量中需要填充的位置

为了保证scatter填充的有效性,需要注意:
(1)target张量在dim方向上的长度不小于source张量,且在其它轴方向的长度与source张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
(2)index张量的shape一般与source ,从而定义了每个source元素的填充位置。这里的一般是指broadcast机制下的例外情况。

下面以一个实际的案例来观察scatter函数:

import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
b_= b.scatter(dim=0, index=torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
print(b_)

# tensor([[0, 6, 0, 0, 9],
#        [0, 0, 2, 8, 0],
#        [5, 1, 7, 0, 4]])

整个函数的操作过程见下面的示意图。因为设定了dim=0,所以会逐列将source中的元素按照index中的位置信息,放入target张量中。
在这里插入图片描述
scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:

labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
#        [0, 0, 0, 1, 0]])
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值