>>> import torch
>>> T1=torch.randn([2,3])
>>> T2=torch.randn([2,3])
>>> T1
tensor([[-0.5566, -1.2475, -0.5865],
        [-2.3513,  0.4147, -1.0349]])
>>> T2
tensor([[-0.1764, -0.8736,  0.9022],
        [-1.1520, -0.1529, -0.1760]])
>>> Tcat=torch.cat([T1,T2],dim=0)
>>> Tcat
tensor([[-0.5566, -1.2475, -0.5865],
        [-2.3513,  0.4147, -1.0349],
        [-0.1764, -0.8736,  0.9022],
        [-1.1520, -0.1529, -0.1760]])
>>> Tcat.shape
torch.Size([4, 3])
>>> Tstack=torch.stack([T1,T2],dim=0)
>>> Tstack
tensor([[[-0.5566, -1.2475, -0.5865],
         [-2.3513,  0.4147, -1.0349]],

        [[-0.1764, -0.8736,  0.9022],
         [-1.1520, -0.1529, -0.1760]]])
>>> Tstack.shape
torch.Size([2, 2, 3])
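torch.stack()和torch.cat()区别 (Difference between torch.stack() and torch.cat())

torch.cat() joins tensors along an existing dimension, so concatenating the two 2×3 tensors on dim=0 yields a 4×3 tensor. torch.stack() inserts a new dimension at the given position, so stacking the same tensors on dim=0 yields a 2×2×3 tensor. Consequently, stack() requires every input to have exactly the same shape. cat() only requires the shapes to match in every dimension except the one being concatenated.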
最新推荐文章于 2024-09-15 18:10:00 发布