Pytorch中的torch.gather函数的理解
Pytorch中的torch.gather函数
pytorch比tensorflow更加编程友好,准备用pytorch试着做一些实验。
先看一下简单的用法示例代码,然后结合官方示例来解读:
b = torch.Tensor([[1,2,3],[4,5,6]])
# 1 2 3
# 4 5 6
print b
index_0 = torch.LongTensor([[1],[2]])
# [[1], 元素对应行标为[[0], 列标为[[0],
# [2]] [1]] [0]]
# index_0[0][0] == 1, index_0[1][0] == 2
print (torch.gather(b, dim=1, index=index_0))
#dim=1, out[i][j][k] = input[i][index[i][j][k]][k]
# [[1], 替换列标 [[0], out元素对 [[0], 列标为[[1],
# [2]] [0]] 应b的行标 [1]] [2]]
# out[0][0]= b[0][1]= 2 ,
# out[1][0]= b[1][2]= 6
index_1 = torch.LongTensor([[0,1],[2,0]])
# [[0,1], 元素对应行标为[[0,0], 列标为[[0,1],
# [2,0]] [1,1]] [0,1]]
#index_1[0][0] == 0, index_1[0][1] == 1, index_1[1][0] == 2, index_1[1][1] == 0
print (torch.gather(b, dim=1, index=index_1))
#dim=1, out[i][j][k] = input[i][index[i][j][k]][k]
# [[0,1], 替换列标[[0,1], out元素对 [[0,0], 列标为[[0,1],
# [2,0]] [0,1]] 应b的行标 [1,1]] [2,0]]
# out[0][0]= b[0][0]= 1, out[0][1]= b[0][1]= 2,
# out[1][0]= b[1][2]= 6, out[1][1]= b[1][0]= 4,
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
# [[0,1,1], 元素对应行标为[[0,0,0], 列标为[[0,1,2],
# [0,0,0]] [1,1,1]] [0,1,2]]
#...
print (torch.gather(b, dim=0, index=index_2))
#dim=0, out[i][j][k] = input[index[i][j][k]][j][k]
# [[0,1,1],替换行标[[0,0,0], out元素对 [[0,1,1], 列标为[[0,1,2],
# [0,0,0]] [1,1,1]] 应b的行标 [0,0,0]] [0,1,2]]
# out[0][0]= b[0][0]= 1,out[0][1]= b[1][1]= 5, out[0][2]= b[1][2]= 6 ,
# out[1][0]= b[0][0]= 1,out[1][1]= b[0][1]= 2, out[1][2]= b[0][2]= 3 ,
输出结果:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
tensor([[2.],
[6.]])
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
结合上面的例子来看官方解读及示例,官方解读是给了三个公式:
torch.gather(input, dim, index, out=None) → Tensor
'''
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
'''
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
可以看出index的形状和input的维度是一致的,都是二维的,里面的index元素数值不能超过input的界限,比如行的不能超过1,列的不能超过2。
理解了这几个式子也就记住了这个方法的用法。