y = torch.index_select(x, dim=1, index=index)
该函数是按照维度对原张量进行切分
index 是切分的索引,通常是一个数字型列表
x是待切分的张量
假设 x = torch.rand(3,2,1) 且 index = torch.LongTensor([0])
tensor([[[0.3048],
[0.0140]],
[[0.6699],
[0.3395]],
[[0.1088],
[0.9452]]])
tensor([0]) 注意这里的index 必须为longTensor类型
y = torch.index_select(x, dim=1, index=index)
将张量x 按照第二个维度(dim从0开始算)进行切分,搜索到第一个元素(index=0)
得到
tensor([[[0.3048]],
[[0.6699]],
[[0.1088]]])