pytorch学习笔记之gather函数
作为一名pytorch的初学者,力争将学习过程中遇到的不懂函数尽量弄懂。俗话说“基础不牢,地动山摇”,可见夯实基础非常重要。最近学习过程中遇到不太懂的函数gather,查阅相关文档最终将这个函数弄明白,下面分享个人对gather函数理解。
import torch as t
print(t.__version__)
a= t.LongTensor([[1,3,5],[2,4,6]])
index1 = t.LongTensor([[0,1,0],[1,0,1]])
index2 = t.LongTensor([[0,1],[1,0]])
result1=t.gather(a,dim=0,index=index1)
result2=t.gather(a,dim=1,index=index2)
print(a)
print(result1)
print(result2)
输出结果:
1.2.0
tensor([[1, 3, 5],
[2, 4, 6]])
tensor([[1, 4, 5],
[2, 3, 6]])
tensor([[1, 3],
[4, 2]])
从输出结果中可以看出,gather的作用是,根据给定的方向(dim=0是按列,dim=1是按行)
将索引index中的值聚合起来。
’‘纸上学来终觉浅,绝知此事要躬行’‘,亲自动手尝试确实会有不一样的收获。
https://blog.csdn.net/mayou32215201/article/details/104295313