pytorch中scatter_介绍

1.官方文档中的介绍

scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原来的。

scatter_(dim, index, src) 的参数有 3 个

  • dim:在哪个维度进行变换
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素

具体的转化关系可以参考下图。
在这里插入图片描述
注意,这个里面的i,j,k可以说都是相同的量。

2.实例介绍

利用官网中的例子

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

这里我们来看看,首先声明了一个2*5的矩阵,里面的值是[0,1),然后我们使用了scatter_,可以看出dim=0,对应上面的公式,我们可以得到。
self[index[0,0],0] = self[0,0] = src[0,0] = 00.3992
(其中i=0,j=0)
self[index[0,1],1] = self[1,1] = src[0,1] = 0.2908
(其中i=0,j=1)
以此类推…

3.实现one-hot

y_train = torch.Tensor(y_train).long()
y_train_onehot = y_train
y_train_onehot = y_train.view(-1, 1)
y_train_onehot = torch.zeros(y_train.size(0), 10).scatter_(1, y_train_onehot, 1).long()

数据说明:
y_train结构是(10000),是有10000个训练集,每个都进行了分类,一共10类(类似于[1,2,3,5,0,6,4,9…])
代码的思想就是首先生成一个全0的矩阵,然后通过对其列上面的值变换为1从而实现one-hot。

当然现在不用这个啦,pytorch中有one_hot函数

torch.nn.functional.one_hot(tensor, num_classes=-1) → LongTenso

在这里插入图片描述
对于上面的数据我们可以直接用

y_train = torch.Tensor(y_train).long()
y_train = torch.nn.functional.one_hot(y_train,10)

输出的结果为
在这里插入图片描述

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值