torch.
stack
(tensors, dim=0, out=None) → Tensor
torch.stack() 将序列连接,形成一个新的tensor结构,此结构中会增加一个维度。连接中的 每个tensor都要保持相同的大小。
参数:
tensors:需要连接的结构
dim:需要扩充的维度
output:输出的结构
例子:
import torch
l = []
for i in range(0,3):
x = torch.rand(2,3)
l.append(x)
print(l)
x = torch.stack(l,dim=0)
print(x.size())
z = torch.stack(l,dim=1)
print(z.size())
output:
[tensor([[0.3615, 0.9595, 0.5895],
[0.8202, 0.6924, 0.4683]]), tensor([[0.0988, 0.3804, 0.5348],
[0.0712, 0.4715, 0.1307]]), tensor([[0.1635, 0.4716, 0.1728],
[0.8023, 0.9664, 0.4934]])]
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
下面例子说明torch.cat()与torch.stack()区别。可以看出,
- stack()是增加新的维度来完成拼接,不改变原维度上的数据大小。
- cat()是在现有维度上进行数据的增加(改变了现有维度大小),不增加新的维度。
x = torch.rand(2,3)
y = torch.rand(2,3)
print(torch.stack((x,y),1).size())
print(torch.cat((x,y),1).size())
output:
torch.Size([2, 2, 3])
torch.Size([2, 6])