torch.stack(c, dim=1) 使c沿着维度1进行堆叠,这样就可以使c维度进行变化。
1.准备数据
import torch
c = []
b = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# stack()函数要求输入为Tensors,而不是Tensor,这里放入两次tensor
c.append(b)
c.append(b)
# print(c.shape) # [2, 3, 4]
2.用torch.stack改变维度
a = torch.stack(c, dim=1)
print(a.shape) # [3, 2, 4]