因为自己遇到了将 Tensor
[1000,2] 截断为 [200,2]
的需求,故在网络上寻找对策。
先是看到了 tensor.gather() , 确实是不错的函数,但为了此需求稍显复杂。
终于发现了torch.index_select() ,好用到爆炸。
献上官方文档的例子:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices) # 0指的是列,按indices[0,2]取得就是x的第一、第三行。
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices) # 1指的是行,按indices[0,2]取得就是x的第一、第三列。
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
非常好理解是吧。
然后我就用这个函数开心的解决了我的问题:
x = torch.randn(1000,2)
ind = []
for i in range(200): # 取tensor的前200行
ind.append(i)
indices = torch.LongTensor(ind).to(device)
out_rnn = torch.index_select(x, 0, indices).to(device)