torch.cat( )
官方:连接给定维数的序列张量。所有张量要么具有相同的形状(连接维度除外),要么为空。
例子:注意每个维度的shape
import torch
x = torch.randn(2, 3, 4)
rx1 = torch.cat((x, x, x), 0)
rx2 = torch.cat((x, x, x), 1)
print(x)
print(rx1)
print(rx2)
结果:
tensor([[[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929]],
[[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386]]])
tensor([[[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929]],
[[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386]],
[[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929]],
[[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386]],
[[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929]],
[[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386]]])
tensor([[[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929],
[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929],
[ 0.3334, 0.9134, 1.0101, -1.8516],
[ 1.2072, -0.2558, 1.1317, 1.6834],
[ 0.3684, 1.2080, -1.0300, 0.4929]],
[[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386],
[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386],
[-0.0713, 1.4526, -0.0669, -1.2307],
[ 0.2329, -0.6852, -0.3124, 0.0389],
[ 1.5162, -0.2998, 0.5022, 0.3386]]])
torch.Size([2, 3, 4])
torch.Size([6, 3, 4])
torch.Size([2, 9, 4])
沿0维拼接,0维度因为3个拼接故从2到6。沿1维拼接,1维度因为3个拼接故从3到9。
torch.stack( )
官方:沿着一个新的维度连接张量序列。所有张量的大小必须相同。
例子:注意每个维度的shape
import torch
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)
rx1 = torch.stack([x1, x2], dim=0)
rx2 = torch.stack([x1, x2], dim=1)
print(x1)
print(x2)
print(rx1)
print(rx2)
print(x1.shape)
print(x2.shape)
print(rx1.shape)
print(rx2.shape)
结果:
tensor([[ 0.3013, -0.2763, 1.1704],
[ 0.2154, 0.8657, -0.2266]])
tensor([[ 0.2167, -0.5974, 0.9155],
[-0.1918, -0.4248, 0.9019]])
tensor([[[ 0.3013, -0.2763, 1.1704],
[ 0.2154, 0.8657, -0.2266]],
[[ 0.2167, -0.5974, 0.9155],
[-0.1918, -0.4248, 0.9019]]])
tensor([[[ 0.3013, -0.2763, 1.1704],
[ 0.2167, -0.5974, 0.9155]],
[[ 0.2154, 0.8657, -0.2266],
[-0.1918, -0.4248, 0.9019]]])
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
Process finished with exit code 0
两个张量,沿0维拼接,产生0维,因为两个拼接故0维为2 。两个张量,沿1维拼接,产生1维,因为两个拼接故1维为2 。
区别:cat是沿维数直接拼接,stack是沿着维数产生维度再进行拼接。