一、基本功能
pytroch官方文档对于这个函数的描述很简略。只有一句话:在维度上连接(concatenate)若干个张量。(这些张量形状相同)。
经过代码总结归纳,可以得到stack(tensors,dim=0,out=None)
函数的功能:
将若干个张量在dim维度上连接,生成一个扩维的张量,比如说原来你有若干个2维张量,连接可以得到一个3维的张量。
设待连接张量维度为n,dim取值范围为-n-1~n,这里得提一下为负的意义:-i为倒数第i个维度。举个例子,对于2维的待连接张量,-1维即3维,-2维即2维。
上代码:
a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[10,20,30],[40,50,60]])
c=torch.tensor([[100,200,300],[400,500,600]])
print(torch.stack([a,b,c],dim=0))
print(torch.stack([a,b,c],dim=1))
print(torch.stack([a,b,c],dim=2))
print(torch.stack([a,b,c],dim=0).size())
print(torch.stack([a,b,c],dim=1).size())
print(torch.stack([a,b,c],dim=2).size())
#输出结果为:
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 10, 20, 30],