pytorch函数之scatter()和scatter_()

前言

这两个函数,理清楚的人很清楚,不清楚的人很不清楚,建议直接看2.举例

官方文档

scatter_()
'官方定义'
scatter(input, dim, index, src) → Tensor

实际使用:如下面

input.scatter_(dim, index, src) → Tensor
'Or'
input.scatter(dim, index, src) → Tensor
'区别是scatter_函数不会回滚,使用后返回的就是更改后的input。而scatter是在内存中生成另外一个对象,不会覆盖原input'
  • input: 我们需要插入数据的起源tensor;也就是想要改变内部的tensor
  • dim:我们想要从哪个维度去改input数据
  • index:给出改的元素索引,也就是位置,说在“坐标”可能好理解一点。
  • src:准备好的插入到input中指定位置的数据。

总结input.scatter_(dim, index, src):从【src源数据】中获取的数据,按照【dim指定的维度】和【index指定的位置】,替换input中的数据。

2. 举例

先看代码

batch_size = 2
hidden_size = 8

src = torch.rand(batch_size, hidden_size)
input_ = torch.zeros(batch_size+1, hidden_size)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]])

print('src\n',src)
print('index\n',index)
print('input_\n',input_)
print('ans:\n',input_.scatter_(0, index, src))
'''
src
 tensor([[0.3304, 0.5643, 0.2362, 0.1929, 0.2400, 0.6672, 0.5217, 0.4471],
        [0.0433, 0.2996, 0.9913, 0.4336, 0.8540, 0.8522, 0.0408, 0.1014]])
index
 tensor([[0, 1, 2, 0, 0, 1, 1, 2],
        [2, 0, 0, 1, 2, 1, 1, 1]])
input_
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
ans:
 tensor([[0.3304, 0.2996, 0.9913, 0.1929, 0.2400, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5643, 0.0000, 0.4336, 0.0000, 0.8522, 0.0408, 0.1014],
        [0.0433, 0.0000, 0.2362, 0.0000, 0.8540, 0.0000, 0.0000, 0.4471]])
'''

比如上述代码,dim=0代表按行赋值,那么index[1][3]=1,代表更改input中的[1]行;另外,index[1][3]对应的src[1][3]的值是0.4336index[1][3]的[3]列,因此是把0.4336这个数值放入input中的[1][3]的位置。

如果还是不太清楚,我们把dim=1设定为按列

src = torch.rand(batch_size, hidden_size).transpose(0,1)
input_ = torch.zeros(batch_size+1, hidden_size).transpose(0,1)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]]).transpose(0,1)

print('src\n',src)
print('index\n',index)
print('input_\n',input_)
# print('ans:\n',input_.scatter_(0, index, src))
print('ans:\n',input_.scatter_(1, index, src))
'''
src
 tensor([[0.3504, 0.3369],
        [0.1163, 0.3850],
        [0.5554, 0.5531],
        [0.0440, 0.2904],
        [0.2444, 0.6650],
        [0.4698, 0.5640],
        [0.1331, 0.5830],
        [0.0408, 0.8508]])
index
 tensor([[0, 2],
        [1, 0],
        [2, 0],
        [0, 1],
        [0, 2],
        [1, 1],
        [1, 1],
        [2, 1]])
input_
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
ans:
 tensor([[0.3504, 0.0000, 0.3369],
        [0.3850, 0.1163, 0.0000],
        [0.5531, 0.0000, 0.5554],
        [0.0440, 0.2904, 0.0000],
        [0.2444, 0.0000, 0.6650],
        [0.0000, 0.5640, 0.0000],
        [0.0000, 0.5830, 0.0000],
        [0.0000, 0.8508, 0.0408]])
'''

同上,举例: dim=1代表按列赋值, index[4][1]=2,代表行是[4]列是[2],说明是把src[4][1]的值,赋值给input[4][2]

  • 10
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值