import torch
a = torch.tensor([0,1,2,3,4],[5,6,7,8,9])
a.shape = [2,5]
假如2为batchsize,5为数据,目标: a循环拼接自己3次成[2,3,5],
[[[0,1,2,3,4],
[0,1,2,3,4],
[0,1,2,3,4]
]],
[[5,6,7,8,9],
[5,6,7,8,9],
[5,6,7,8,9]
]]
通过下面命令可以实现:
a.unsqueeze(1) = [2,1,5]
a.repeat(1,3,1) = [2,3,5]
------------------------------
按行按索引取值
a =torch.rand(5,4)
b = torch.tensor([0,0,1,1,2]).view(-1,1)
a.gather(1,b)