gather函数真的是让人头大,看了一会,趁着还有些印象赶紧记下来。
首先说明一下,gather函数的input与index必须有相同的维度,例如你输入的数据是二维的,那么index也必须是二维的,但他们的shape可以不同,最终输出的结果与输入的index相同。接下来首先通过二维张量举例。
import torch
input = torch.arange(16).view(4, 4)
"""
[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
"""
接下来创建一个index,当然这个index也要是二维的。
index = torch.LongTensor([[0, 2, 2, 3]])
首先讨论当dim=0的情况
output = input.gather(dim=0, index)
"""
[[0, 9, 10, 15]]
"""
当dim=0的时候,可以看作按照row来选择,其中index里的数值则表示选择了哪一行,举个栗子,原index的数值为[0.2.2.3],那么表示这四个数分别来自第0行、第2行、第2行、第3行,好了,那接下来行确定了,具体选择哪个数字呢,那就要根据原index中数字所在的位置确定了,比如说[0.2.2.3]中0的索引就是0,第一个2的索引为1,第二个2索引的位置为2,3的索引为3,将这两部分拼凑在一起吧,于是我们得到了最终想要的结果。
紧接着dim=1的情况也好说明了吧
index = torch.LongTensor([[0, 2, 2, 3]])
output = input.gather(dim=1, index)
"""
[[0 6 10 15]]
"""
同理,只要我们将index里面的数据按照列来选择就好了,列确定之后,具体的数值依旧是按照原index里面的索引位置确定。
当二维的我们知道怎么回事了,现在该讨论三维数据下的情况了。老规矩,先建个三维张量吧。当然我们的index也要变成三维的了。
c=torch.arange(24).view(3,2,4)
'''
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
'''
index=tensor([[[0, 0, 0, 0],
[1, 1, 1, 1]],
[[1, 1, 1, 1],
[2, 2, 2, 2]],
[[2, 2, 2, 2],
[2, 2, 2, 2]]])
当数据变成三维的时候,这个时候多出了一个维度的数据,这时候该怎么理解呢。
dim=0, 从一个batch中选择数据。
dim=1, 从一个batch中选择某些行。
dim=2,从一个batch中选择某些列。
举个栗子。当dim=0时,里面的数字代表了batch,所以说我们先看最上面的一行张量[0,0,0,0],这句话代表,这行的四个数字全部来自input的第0个batch块,也就是input中0~7的这部分,那具体的选择哪个数字呢,还是老规矩,确定每个数字的索引值就好啦,[0,0,0,0]在index其所在batch块中的索引值为[0,0] 、[0,1]、[0,2]、[0,3],我们再将我们所获取的信息拼接起来,就得到了我们想要的结果了。听起来有点绕,但是我们只要根据dim所给的值确定对应的维度,再根据位置获取对应元素,理解起来就容易多了。
c.gather(dim=0,index=index)
'''
c = tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
index=tensor([[[0, 0, 0, 0],
[1, 1, 1, 1]],
[[1, 1, 1, 1],
[2, 2, 2, 2]],
[[2, 2, 2, 2],
[2, 2, 2, 2]]])
output=tensor([[[ 0, 1, 2, 3],
[12, 13, 14, 15]],
[[ 8, 9, 10, 11],
[20, 21, 22, 23]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
'''
再来个dim=1的例子:
input=tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
index=tensor([[[1, 0, 0, 1],
[0, 0, 0, 0]],
[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[0, 0, 0, 0],
[1, 1, 1, 1]]])
output=tensor([[[ 4, 1, 2, 7],
[ 0, 1, 2, 3]],
[[12, 13, 14, 15],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
有了上面的例子,这次就应该更好理解了一些,既然这次时dim=1,那么就说明batch块是确定的了,这不就由回到二维的情况了嘛。
也不差dim=2的例子了。
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[[1, 0, 0, 1],
[0, 0, 0, 0]],
[[0, 1, 2, 3],
[1, 1, 1, 1]],
[[0, 0, 0, 0],
[1, 1, 1, 1]]])
tensor([[[ 1, 0, 0, 1],
[ 4, 4, 4, 4]],
[[ 8, 9, 10, 11],
[13, 13, 13, 13]],
[[16, 16, 16, 16],
[21, 21, 21, 21]]])