2021-03-22

记录一下 pytorch中scatter_的用法
scatter_(dim, index, src)
dim: 按哪个轴填充
index:src数据放到哪一行(dim=0时)或列(dim=1时)
src:源数组
return: 填充后的数组

代码:

>>> src = torch.rand(2,5)
>>> src 
tensor([[0.0763, 0.2093, 0.2751, 0.1032, 0.8062],
        [0.5539, 0.2218, 0.2150, 0.3601, 0.5296]])

>>> index = torch.tensor([[0,1,3,2,1],[1,2,3,1,4]])
>>> index
tensor([[0, 1, 3, 2, 1],
        [1, 2, 3, 1, 4]])

>>> var_dim0 = torch.zeros(5,5)
>>> var_dim0 
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
        
>>> var_pad = var_dim0.scatter_(dim=0, index = index, src =src )
>>> var_pad
tensor([[0.0763, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5539, 0.2093, 0.0000, 0.3601, 0.8062],
        [0.0000, 0.2218, 0.0000, 0.1032, 0.0000],
        [0.0000, 0.0000, 0.2150, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.5296]])

解释:
index的每个元素代表src的每个值放到var_dim0的哪一行(dim=0时),src元素放在var_dim0哪一列取决于src的这个元素在哪一列。
src[0,0]是0.0763,与其对应的index是index[0,0]为0(src与index元素是一一对应的),说明0.0763放到var_dim0的第0行,列数就是src这个元素所在列数,即0。所以0.0763放到var_dim0的[0,0]位置。更直观点就是dim=0时src的值只能放到var_dim0的对应列,而行数由对应的index确定。比如src第一列0.0736和0.5539,则这两个值只能放在var_dim0的第一列的某位置,0.0736放在第0行,0.5539放在index[1,0] = 1。dim=1时同理,只能把src对应行的值放到var_dim0的同一行,而列位置由index决定。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值