前言
这两个函数,理清楚的人很清楚,不清楚的人很不清楚,建议直接看2.举例
官方文档
scatter_()
'官方定义'
scatter(input, dim, index, src) → Tensor
实际使用:如下面
input.scatter_(dim, index, src) → Tensor
'Or'
input.scatter(dim, index, src) → Tensor
'区别是scatter_函数不会回滚,使用后返回的就是更改后的input。而scatter是在内存中生成另外一个对象,不会覆盖原input'
- input: 我们需要插入数据的起源
tensor
;也就是想要改变内部的tensor
- dim:我们想要从哪个维度去改
input
数据 - index:给出改的元素索引,也就是位置,说在“坐标”可能好理解一点。
- src:准备好的插入到
input
中指定位置的数据。
总结:input.scatter_(dim, index, src)
:从【src
源数据】中获取的数据,按照【dim
指定的维度】和【index
指定的位置】,替换input
中的数据。
2. 举例
先看代码
batch_size = 2
hidden_size = 8
src = torch.rand(batch_size, hidden_size)
input_ = torch.zeros(batch_size+1, hidden_size)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]])
print('src\n',src)
print('index\n',index)
print('input_\n',input_)
print('ans:\n',input_.scatter_(0, index, src))
'''
src
tensor([[0.3304, 0.5643, 0.2362, 0.1929, 0.2400, 0.6672, 0.5217, 0.4471],
[0.0433, 0.2996, 0.9913, 0.4336, 0.8540, 0.8522, 0.0408, 0.1014]])
index
tensor([[0, 1, 2, 0, 0, 1, 1, 2],
[2, 0, 0, 1, 2, 1, 1, 1]])
input_
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]])
ans:
tensor([[0.3304, 0.2996, 0.9913, 0.1929, 0.2400, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5643, 0.0000, 0.4336, 0.0000, 0.8522, 0.0408, 0.1014],
[0.0433, 0.0000, 0.2362, 0.0000, 0.8540, 0.0000, 0.0000, 0.4471]])
'''
比如上述代码,dim=0
代表按行赋值,那么index[1][3]=
1,代表更改input
中的[1]行;另外,index[1][3]
对应的src[1][3]
的值是0.4336
;index[1][3]
的[3]列,因此是把0.4336
这个数值放入input
中的[1][3]的位置。
如果还是不太清楚,我们把dim=1
设定为按列
src = torch.rand(batch_size, hidden_size).transpose(0,1)
input_ = torch.zeros(batch_size+1, hidden_size).transpose(0,1)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]]).transpose(0,1)
print('src\n',src)
print('index\n',index)
print('input_\n',input_)
# print('ans:\n',input_.scatter_(0, index, src))
print('ans:\n',input_.scatter_(1, index, src))
'''
src
tensor([[0.3504, 0.3369],
[0.1163, 0.3850],
[0.5554, 0.5531],
[0.0440, 0.2904],
[0.2444, 0.6650],
[0.4698, 0.5640],
[0.1331, 0.5830],
[0.0408, 0.8508]])
index
tensor([[0, 2],
[1, 0],
[2, 0],
[0, 1],
[0, 2],
[1, 1],
[1, 1],
[2, 1]])
input_
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
ans:
tensor([[0.3504, 0.0000, 0.3369],
[0.3850, 0.1163, 0.0000],
[0.5531, 0.0000, 0.5554],
[0.0440, 0.2904, 0.0000],
[0.2444, 0.0000, 0.6650],
[0.0000, 0.5640, 0.0000],
[0.0000, 0.5830, 0.0000],
[0.0000, 0.8508, 0.0408]])
'''
同上,举例: dim=1
代表按列赋值, index[4][1]=2
,代表行是[4]列是[2],说明是把src[4][1]
的值,赋值给input[4][2]