pytorch中scatter函数的理解

文章参考:https://blog.csdn.net/weixin_45547563/article/details/105311543?spm=1001.2101.3001.6650.4&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-4-105311543-blog-104308528.pc_relevant_3mothn_strategy_recovery&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-4-105311543-blog-104308528.pc_relevant_3mothn_strategy_recovery&utm_relevant_index=9

参考文章讲得很清楚,但只举了一个例子,本文又扩展补充了一下~

一、scatter函数简介

scatter_(input, dim, index, src):将src中数据根据index中的索引按照dim的方向填进input。可以理解成放置元素或者修改元素

dim:沿着哪个维度进行索引

index:用来 scatter 的元素索引

src:用来 scatter 的源元素,可以是一个标量或一个张量

填充规则如下:

# dim=0时,逐列(j)进行"行填充"
x[ index[i][j] ] [j] = src[i][j]

# dim=1时,逐行(i)进行"列填充"
x [i] [ index[i][j] ] = src[i][j]

二、举例

我们还是用参考文章中的例子来展示:


x = torch.rand(2, 5)
 
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
 
torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src)

1.当dim=0时:

因为index矩阵中最大元素为2,所以答案矩阵(记为x)的行数(从0开始索引)为2;

因为src矩阵(或index矩阵)列数为4(从0开始索引),所以答案矩阵的列数为4;

得到的矩阵如下:

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

根据dim=0时的填充原则:x[ index[i][j] ] [j] = src[i][j]

index[i][j]对应的数代表了x的行数,j代表了x的列数, 该行该列中x对应的元素值即为src第i行第j列对应的元素值

例如index[0][1]=1

因此x[1][1] = 0.3340 (x就是我们的答案矩阵哦)

0

0

0

0

0

0

0.3340

0

0

0

0

0

0

0

0

同理index[1][2] = 0,因此x[0][2] = src[1][2] (随便举的例子,实际上试验index中的任何一个元素都可以)

0

0

0.0074

0

0

0

0.3340

0

0

0

0

0

0

0

0

其他的也是一样,比如index[1][3] = 1,因此x[1][3] = src[1][3]

0

0

0.0074

0

0

0

0.3340

0

0.0943

0

0

0

0

0

0

最终可以得到x矩阵

2.当dim=1时

dim=1时,逐行进行"列填充",也就是x [i] [ index[i][j] ] = src[i][j]

index[i][j]对应的数字变成了答案矩阵x的列索引,i作为答案矩阵x的行索引,其对应的元素值为src[i][j]

因此答案矩阵的最大列数不超过index的最大元素值,最大行数不超过src矩阵的行数

举两个例子:

三、scatter函数在独热编码中的应用(不用看)

这部分是我害怕自己忘了书里的代码含义,所以记下来给自己看的~

target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)

target.shape[0]代表了葡萄酒数据集中的数据总条数,10代表葡萄酒的类别总数

因此此处src是一个4898行,10列的tensor

target.unsqueeze(1)是我们的index,它对应了一个4898行,1列的tensor

target_onehot的每一行记录按顺序对应着target.unsqueeze(1)的每一行记录,假如index[i][j]=6,

那么target_onehot[i][6]=1,这就和独热编码对上了~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值