Understand torch.scatter

This post explains in detail how PyTorch's `scatter_(dim, index, src)` function works, including how it modifies the input tensor as an in-place operation, with graphical examples along different dimensions (rows and columns). It pays particular attention to the broadcasting behavior when src is a scalar.

1. Official Documentation

First, note that scatter_() is an in-place function, meaning that it changes the values of the input tensor.

The official documentation for scatter_(dim, index, src) → Tensor tells us that the parameters are the dimension dim, the index tensor, and the source tensor. dim specifies the dimension along which the index tensor operates; the other dimensions stay unchanged. As the function name suggests, the goal is to scatter values from the source tensor into the input tensor self: we loop through the values in the source tensor, find each value's target position in the input tensor, and replace the old value there.

Note that src can also just be a scalar. In this case, we would just scatter this single value according to the index tensor.

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
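The indexing rule above can be reproduced with explicit Python loops. The sketch below does this for the 2-D, dim == 0 case; the helper name manual_scatter_dim0 is my own, not a PyTorch API.

```python
import torch

# A sketch of the dim == 0 rule, 2-D case: out[index[i][j]][j] = src[i][j].
def manual_scatter_dim0(target, index, src):
    out = target.clone()
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            out[index[i][j]][j] = src[i][j]
    return out

src = torch.arange(1., 11.).view(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
manual = manual_scatter_dim0(torch.zeros(3, 5), index, src)
# Compare against the built-in (no duplicate targets here, so the
# result is deterministic).
builtin = torch.zeros(3, 5).scatter_(0, index, src)
print(torch.equal(manual, builtin))  # True
```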

2. Graphical Diagram for dim=0

For simplicity, let us consider two-dimensional matrices here. Let us first understand dim.

When dim=0, the index of rows will be based on the index tensor, and the index of columns will not change, and this means the jth column of the source tensor will only be scattered to the jth column of the input tensor. Let us try to manually update the input tensor step by step using the following example.

import torch
import numpy as np
src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
print(src)

> tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

input_tensor = torch.zeros(3, 5)
print(input_tensor)

> tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index_tensor)

> tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])

## try to manually work out the result
dim = 0
print(input_tensor.scatter_(dim, index_tensor, src))

> ...

Step 1: scatter the 1st column of src to the 1st column of input_tensor. Matching with the 1st column of index tensor. We would scatter 1 to row 0, scatter 6 to row 2.

Step 2: scatter the 2nd column of src to the 2nd column of input_tensor. Matching with the 2nd column of index tensor. We would scatter 2 to row 1, scatter 7 to row 0.

Step 3/4/5: do the scattering in the same way. In the end, we get the following result.

Check it in Python. Well done!

> tensor([[ 1.,  7.,  8.,  4.,  5.],
        [ 0.,  2.,  0.,  9.,  0.],
        [ 6.,  0.,  3.,  0., 10.]])

Note that the values in the index tensor represent row indices when dim=0, which implicitly requires the max value of the index tensor to be smaller than the number of rows in the input. Generally speaking, the following should be True.

input_tensor.shape[dim] > index_tensor.max().item()
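A quick sanity check of this constraint: an index value that is not smaller than input_tensor.shape[dim] makes scatter_ raise a RuntimeError (the exact error message may vary between PyTorch versions).

```python
import torch

input_tensor = torch.zeros(3, 5)
bad_index = torch.tensor([[3, 0, 0, 0, 0]])  # 3 >= input_tensor.shape[0] == 3
src = torch.ones(1, 5)
try:
    input_tensor.scatter_(0, bad_index, src)
    ok = False
except RuntimeError:
    ok = True  # out-of-range index rejected
print(ok)  # True
```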

3. Graphical Diagram for dim = 1

Similarly, we can work it out when dim=1. Let us try the following example.

src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
input_tensor = torch.zeros(3, 5)
index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
dim = 1
print(input_tensor.scatter_(dim, index_tensor, src))

Step 1: scatter the 1st row of src to the 1st row of input_tensor. 1 to col3, 2 to col0, 3 to col2, 4 to col1, 5 to col4.

Step 2: scatter the 2nd row of src to the 2nd row of input_tensor.

Note that there are two 1's in the 2nd row of index_tensor. To make the update clearer, I will split this step into two substeps.

Step 2.1: scatter 6 to col2, 7 to col0, 8 to col1, 9 to col3.

Step 2.2: scatter 10. The corresponding index is 1, but 8 has already been there. What we would do is to replace 8 with 10.

Done! Let’s check the results. Correct! 😄

> tensor([[ 2.,  4.,  3.,  1.,  5.],
        [ 7., 10.,  6.,  9.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

4. Graphical Diagram for a Trickier Example

Finally, let’s try a trickier example where src is a scalar and the index tensor is smaller than the input tensor.

Note that the number of dimensions of the index tensor should always match that of the input tensor, which is why you may sometimes see unsqueeze() in others’ code. Also note that along every dimension other than dim, the index tensor should be no larger than the input tensor.

input_tensor = torch.from_numpy(np.arange(1, 16)).float().view(3, 5)  # 2 dimensions
index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1)  # unsqueeze so index also has 2 dimensions
src = 0
dim = 1
print(input_tensor.scatter_(dim, index_tensor, src))

Note that when src is a scalar, we are actually using the broadcasted version which has the same size as the index tensor.

dim = 1, so we do the scattering row by row. For row 1, we scatter 0 to col 4; for row 2, we scatter 0 to col 0; for row 3, we scatter 0 to col 1.

Checking the result — great job! 🌟

> tensor([[ 1.,  2.,  3.,  4.,  0.],
        [ 0.,  7.,  8.,  9., 10.],
        [11.,  0., 13., 14., 15.]])

Hope this tutorial will help you better understand torch.scatter_()!

References:

[1] torch.Tensor — PyTorch 2.2 documentation
