[pytorch函数合集]torch.index_select
这次讲解的函数是scatter_max(),该函数有4个参数,其中前三个分别为,(1)将要被操作的目标数组arr1(2)设定操作数组arr1哪个维度的参数dim(3)目录函数(我自己起的名)index
import torch
from torch_scatter import scatter_min, scatter_max, scatter_add, scatter_mean
arr1 = torch.randint(1,10,[3,4])
print(arr1)
out = torch.index_select(arr1, 0, torch.tensor([0, 0, 1, 2, 1]))
print(out)
如上如所示,第一个参数arr1是我们需要操作的数组;第二个参数是需要操作的维度,维度是0为按行操作,维度是1为按列操作;第三个参数为操作的目录,意为从arr1中需要摘录的数据,这里为:按顺序摘录arr1的第0,0,1,2,1行数据:
tensor([[4, 7, 8, 4],
[2, 4, 4, 4],
[9, 2, 5, 8]])
tensor([[4, 7, 8, 4],
[4, 7, 8, 4],
[2, 4, 4, 4],
[9, 2, 5, 8],
[2, 4, 4, 4]])