import torch
a=torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
print(a.shape) # 输出: torch.Size([3, 3])
aa=torch.stack((a,a),dim=1) #stack 会在dim插入维度(维度增加),大小为堆的数目
print(aa.shape) # 输出: torch.Size([3,2,3])
import torch
a=torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
print(a.shape) # 输出: torch.Size([3, 3])
aa=torch.stack((a,a),dim=1) #stack 会在dim插入维度(维度增加),大小为堆的数目
print(aa.shape) # 输出: torch.Size([3,2,3])