【PyTorch】scatter(dim, index, source)和scatter_(dim, index, src)的使用举例

参考链接: scatter(dim, index, source) → Tensor
参考链接: scatter_(dim, index, src) → Tensor

函数功能说明:

将src中的元素按照指定维度和指定索引写入到self自身张量中.

写入的规则是这样的:
依次遍历index,比如当前遍历到index[i][j][k],
然后从src中相应位置取出一个元素,即取出元素src[i][j][k],
将该元素写入到src,
在src中写入的位置是类似于(i,j,k),
但是在dim维度上是index[i][j][k].
举个例子如果dim == 0:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
举个例子如果dim == 1:
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
举个例子如果dim == 2:
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2





补充:
self, index 和 src三者的维度个数必须相同.
如果不给定src张量,也可以给定一个value浮点数.
index内的数值范围是[0, self.size(dim) - 1],
否则索引越界了,并且index的值在dim维度上必须不能有重复.

在所有维度上,index的维度长度都不得大于src,否则索引越界.

除了在dim维度上,index.size(d) <= self.size(d).
笔者注:引文index在dim维度上的值不能有重复,
因此index.size(dim) <= self.size(dim)也同样成立.

在这里插入图片描述

代码实验举例:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000023BAB59D330>
>>>
>>> x = torch.rand(2, 5)
>>> x
tensor([[0.9817, 0.9880, 0.8879, 0.3911, 0.8532],
        [0.2367, 0.6074, 0.6374, 0.7830, 0.1322]])
>>>
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[0.9817, 0.6074, 0.6374, 0.3911, 0.8532],
        [0.0000, 0.9880, 0.0000, 0.7830, 0.0000],
        [0.2367, 0.0000, 0.8879, 0.0000, 0.1322]])
>>>
>>>
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[0.0000, 0.0000, 1.2300, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.2300]])
>>>
>>>
>>>
>>> torch.zeros(3, 5).scatter_(1, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Invalid index in scatter at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:504
>>> x
tensor([[0.9817, 0.9880, 0.8879, 0.3911, 0.8532],
        [0.2367, 0.6074, 0.6374, 0.7830, 0.1322]])
>>> torch.zeros(3, 5).scatter_(1, torch.tensor([[0, 1, 2], [2, 0, 1]]), x)
tensor([[0.9817, 0.9880, 0.8879, 0.0000, 0.0000],
        [0.6074, 0.6374, 0.2367, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
>>>
>>> # 以上说明index参数在指定dim维度上不可重复
>>>
>>>

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值