pytorch:scatter_

我终于想明白这个函数的意思啦!!!

scatter的意思:服务组件架构

data.scatter_(dim,index,src)

将src中数据根据index中的索引按照dim的方向填进data。

import torch
x=torch.tensor([[7, 2, 3,4, 5],
        [1.1, 2.2, 3.3, 4.4, 5.5]])
y=torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
print(x)
z=torch.zeros(3,5).scatter_(0,y,x)
print(z)
#输出
'''
tensor([[7.0000, 2.0000, 3.0000, 4.0000, 5.0000],
        [1.1000, 2.2000, 3.3000, 4.4000, 5.5000]])
tensor([[7.0000, 2.2000, 3.3000, 4.0000, 5.0000],
        [0.0000, 2.0000, 0.0000, 4.4000, 0.0000],
        [1.1000, 0.0000, 3.0000, 0.0000, 5.5000]])

'''

y=torch.zeros(3,5).scatter_(0,y,x)

这句话的意思是:

首先:创建一个(3,5)的〇为z,然后,将x中的值补充到〇中。

根据z进行补充,如果若干个数字填充到相同的位置,则按顺序,看最后是什么,则确定是什么。

[[0,1,2,0,0],[2,0,0,1,2]]

(0,0),(0,1),(0,2),(0,3),(0,4)

(1,0),(1,1),(1,2),(1,3),(1,4)

这是x数据的位置,这时候,要x数据的位置坐标,根据y更改,由于dim=0,即更改第0维度

(0,0),(1,1),(2,2),(0,3),(0,4)

(2,0),(0,1),(0,2),(1,3),(2,4)

x数据在z中位置就是上面坐标。如果有几个数据对应的坐标相同,则一层层覆盖,显示最后一个数据。

import torch
x=torch.tensor([[7, 2, 3,4, 5],
        [1.1, 2.2, 3.3, 4.4, 5.5]])
y=torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
print(x)
z=torch.zeros(3,5).scatter_(1,y,x)
print(z)
#输出
'''
tensor([[7.0000, 2.0000, 3.0000, 4.0000, 5.0000],
        [1.1000, 2.2000, 3.3000, 4.4000, 5.5000]])
tensor([[5.0000, 2.0000, 3.0000, 0.0000, 0.0000],
        [3.3000, 4.4000, 5.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
'''

(0,0),(0,1),(0,2),(0,3),(0,4)

(1,0),(1,1),(1,2),(1,3),(1,4)

这是x数据的位置,这时候,要x数据的位置坐标,根据y=[[0,1,2,0,0],[2,0,0,1,2]]更改,由于dim=1,即更改第1维度

(0,0),(0,1),(0,2),(0,0),(0,0)

(1,2),(1,0),(1,0),(1,1),(1,2)

三个数字的新坐标都是(0,0),则以最后一次也就是x中数据5占领此位置。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值