Pytorch基础 - 8. scatter() / scatter_() 函数

本文介绍了PyTorch中的scatter()和scatter_()函数,这两个函数用于根据索引映射创建或修改Tensor。文章详细解释了函数的参数,提供了2维示例,并阐述了其在one-hot编码中的应用。通过实例,读者可以理解如何沿着特定维度进行索引操作,以及当源数据为标量时的情况。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1. scatter() 定义和参数说明

2. 示例和详细解释

3. scatter() 常见用途


1. scatter() 定义和参数说明

scatter() 或 scatter_() 常用来返回根据index映射关系映射后的新的tensor。其中,scatter() 不会直接修改原来的 Tensor,而 scatter_() 直接在原tensor上修改。

官方文档:torch.Tensor.scatter_ — PyTorch 2.0 documentation

 参数定义:

  • dim:沿着哪个维度进行索引
  • index:索引值
  • src:数据源,可以是张量,也可以是标量

简言之 scatter() 是通过 src 来修改另一个张量,修改的元素值和位置由 dim 和 index 决定

2. 示例和详细解释

在官方文档中,给出了3维tensor的具体操作说明,看起来很蒙,没关系继续往下看

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

 接下来的示例,我们以2维为例,那上面的公式简化为如下,

self[index[i][j]][j] = src[i][j]  # if dim == 0
self[i][index[i][j]] = src[i][j]  # if dim == 1

示例:将全零的张量,根据index和scr进行值的变化

src = torch.arange(1, 11).reshape((2, 5))
# src: tensor([[0.8351, 0.2974, 0.9028, 0.4250, 0.0370],
#              [0.4564, 0.6832, 0.6854, 0.6056, 0.7118]])
    
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 4, 2, 3]])
dist = torch.zeros(2, 5, dtype=src.dtype).scatter(1, index, src)
    
# dist: tensor([[0.0370, 0.2974, 0.9028, 0.0000, 0.0000],
#               [0.4564, 0.6832, 0.6056, 0.7118, 0.6854]])

将上述张量使用表格表示: 

当 dim = 1时,dist[i

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值