pytorch scatter和gather函数理解

本文介绍了PyTorch中的scatter_和gather函数,详细解释了它们的参数和用法。scatter函数用于将源张量src的元素按照指定的index散列到目标张量中,而gather函数则按照index从输入张量中选择元素形成新的张量。此外,还展示了如何使用scatter函数将向量转化为one-hot形式。
摘要由CSDN通过智能技术生成

Scatter函数

scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

简单理解就是将 src 张量中的元素散落到 self 张量中,具体选择哪个元素,选择的元素散落到哪个位置由index张量决定,具体的映射规则为:

# 其中 i,j,k 为index张量中元素坐标。
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
参数
  • dim(int) 指index数组元素替代的坐标(dim = 0 替代src中的横坐标)
  • index (LongTensor) 可以为空,最大与src张量形状相同
  • src(Tensor or float) 源张量
  • reduce 聚集函数(src替换元素与self中被替换元素执行的操作,默认是替代,可以进行add,multiply等操作)

具体例子:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值