import torch
# 创建数据
feat = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 创建索引张量
index = torch.tensor([[0, 2, 2, 0], [2, 1, 0, 2]])
# 在第0维度上根据索引收集数据
result = feat.gather(dim=0, index=index)
print(result)
gather函数按照指定维度和索引信息提取数据,
当dim=0时,这时可以看作从行维度提取信息,此时index里面的数值都是表示行的信息,其所在的列的位置表示列的信息。举例:index[0][1]=2, 此时2表示第3行,2在index第2列(列位置信息)所以index[0][1]=2 表示取feat矩阵第3行第2列为10。依此类推index中的所有元素。
tensor([[ 1, 10, 11, 4],
[ 9, 6, 3, 12]])
当dim=1时,这时可以看作从列维度提取信息,此时index里面的数值都是表示列的信息,其所在的行的位置表示行的信息。举例:index[0][1]=2, 此时2表示第3列,2在index第1行(列位置信息)所以index[0][1]=2 表示取feat矩阵第3列第1行为3。依此类推index中的所有元素。
tensor([[1, 3, 3, 1], [7, 6, 5, 7]])