参考文章讲得很清楚,但只举了一个例子,本文又扩展补充了一下~
一、scatter函数简介
scatter_(input, dim, index, src):将src中数据根据index中的索引按照dim的方向填进input。可以理解成放置元素或者修改元素
dim:沿着哪个维度进行索引
index:用来 scatter 的元素索引
src:用来 scatter 的源元素,可以是一个标量或一个张量
填充规则如下:
# dim=0时,逐列(j)进行"行填充"
x[ index[i][j] ] [j] = src[i][j]
# dim=1时,逐行(i)进行"列填充"
x [i] [ index[i][j] ] = src[i][j]
二、举例
我们还是用参考文章中的例子来展示:
x = torch.rand(2, 5)
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src)
![](https://img-blog.csdnimg.cn/img_convert/93bf80f2e7fb29b33b46f79e7d2e1fc3.png)
1.当dim=0时:
因为index矩阵中最大元素为2,所以答案矩阵(记为x)的行数(从0开始索引)为2;
因为src矩阵(或index矩阵)列数为4(从0开始索引),所以答案矩阵的列数为4;
得到的矩阵如下:
0 | 0 | 0 | 0 | 0 |
0 | 0 | 0 | 0 | 0 |
0 | 0 | 0 | 0 | 0 |
根据dim=0时的填充原则:x[ index[i][j] ] [j] = src[i][j]
index[i][j]对应的数代表了x的行数,j代表了x的列数, 该行该列中x对应的元素值即为src第i行第j列对应的元素值
例如index[0][1]=1
![](https://img-blog.csdnimg.cn/img_convert/331b834eba1b1044b1fc605bbce409ae.jpeg)
因此x[1][1] = 0.3340 (x就是我们的答案矩阵哦)
0 | 0 | 0 | 0 | 0 |
0 | 0.3340 | 0 | 0 | 0 |
0 | 0 | 0 | 0 | 0 |
同理index[1][2] = 0,因此x[0][2] = src[1][2] (随便举的例子,实际上试验index中的任何一个元素都可以)
0 | 0 | 0.0074 | 0 | 0 |
0 | 0.3340 | 0 | 0 | 0 |
0 | 0 | 0 | 0 | 0 |
其他的也是一样,比如index[1][3] = 1,因此x[1][3] = src[1][3]
0 | 0 | 0.0074 | 0 | 0 |
0 | 0.3340 | 0 | 0.0943 | 0 |
0 | 0 | 0 | 0 | 0 |
最终可以得到x矩阵
![](https://img-blog.csdnimg.cn/img_convert/163c44b4d03131ffe6433ea8ae4fdc41.png)
2.当dim=1时
dim=1时,逐行进行"列填充",也就是x [i] [ index[i][j] ] = src[i][j]
index[i][j]对应的数字变成了答案矩阵x的列索引,i作为答案矩阵x的行索引,其对应的元素值为src[i][j]
因此答案矩阵的最大列数不超过index的最大元素值,最大行数不超过src矩阵的行数
举两个例子:
![](https://img-blog.csdnimg.cn/img_convert/5a69d13d5ebd198389fa87e051cba170.png)
![](https://img-blog.csdnimg.cn/img_convert/119420f20eaddd81ebbd1360b09ea86a.png)
三、scatter函数在独热编码中的应用(不用看)
这部分是我害怕自己忘了书里的代码含义,所以记下来给自己看的~
target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
target.shape[0]代表了葡萄酒数据集中的数据总条数,10代表葡萄酒的类别总数
因此此处src是一个4898行,10列的tensor
target.unsqueeze(1)是我们的index,它对应了一个4898行,1列的tensor
target_onehot的每一行记录按顺序对应着target.unsqueeze(1)的每一行记录,假如index[i][j]=6,
那么target_onehot[i][6]=1,这就和独热编码对上了~