理解张量拼接(torch.cat)

拼接

维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。 选择一个dim进行拼接的时候其他两个维度大小要相等
![[Pasted image 20240808214248.png]]

对于三维张量,理解 torch.catdim 参数确实变得更加抽象,但原理是相同的。让我们通过一个具体的例子来说明这一点。

import torch

# 创建两个 3D 张量
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

print("Tensor a shape:", a.shape)
print(a)
print("\nTensor b shape:", b.shape)
print(b)

# dim=0 连接
c_dim0 = torch.cat([a, b], dim=0)
print("\nResult of torch.cat([a, b], dim=0):")
print("Shape:", c_dim0.shape)
print(c_dim0)

# dim=1 连接
c_dim1 = torch.cat([a, b], dim=1)
print("\nResult of torch.cat([a, b], dim=1):")
print("Shape:", c_dim1.shape)
print(c_dim1)

# dim=2 连接
c_dim2 = torch.cat([a, b], dim=2)
print("\nResult of torch.cat([a, b], dim=2):")
print("Shape:", c_dim2.shape)
print(c_dim2)

现在让我们详细解释这个三维张量的例子:

  1. 初始张量:

    • ab 都是形状为 (2, 2, 2) 的 3D 张量。
    • 可以将它们想象成两个 2x2 的矩阵堆叠在一起。
  2. dim=0 连接:

    • 结果形状:(4, 2, 2)
    • 这相当于在第一个维度上堆叠张量。
    • 可以理解为将 b 放在 a 的"下面",增加了第一个维度的大小。
  3. dim=1 连接:

    • 结果形状:(2, 4, 2)
    • 这相当于在第二个维度上堆叠张量。
    • 可以理解为在每个 2x2 矩阵的"行"方向上扩展,将 b 的行添加到 a 的每个对应部分的下方。
  4. dim=2 连接:

    • 结果形状:(2, 2, 4)
    • 这相当于在第三个维度(最内层)上堆叠张量。
    • 可以理解为在每个 2x2 矩阵的"列"方向上扩展,将 b 的列添加到 a 的每个对应部分的右侧。

理解三维张量 torch.cat 的关键点:

  1. 维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。

  2. dim=0:增加"深度"或"批次"的数量。

  3. dim=1:增加每个"深度"层或"批次"中的行数。

  4. dim=2:增加每行中的元素数量(列数)。

  5. 保持其他维度:除了被连接的维度,其他维度的大小保持不变。

  6. 形状变化:只有指定的 dim 对应的维度大小会改变(增加),其他维度大小保持不变。

  7. 一致性:要连接的张量在非连接维度上的大小必须相同。

3D Matrix Visualization

Let’s visualize the 3D matrices a and b, and their concatenation results.

Matrix a (2x2x2):
Depth 0:    Depth 1:
+---+---+   +---+---+
| 1 | 2 |   | 5 | 6 |
+---+---+   +---+---+
| 3 | 4 |   | 7 | 8 |
+---+---+   +---+---+
Matrix b (2x2x2):
Depth 0:    Depth 1:
+----+----+ +----+----+
| 9  | 10 | | 13 | 14 |
+----+----+ +----+----+
| 11 | 12 | | 15 | 16 |
+----+----+ +----+----+

Concatenation Results:

dim=0 (4x2x2):
Depth 0:    Depth 1:    Depth 2:    Depth 3:
+---+---+   +---+---+   +----+----+ +----+----+
| 1 | 2 |   | 5 | 6 |   | 9  | 10 | | 13 | 14 |
+---+---+   +---+---+   +----+----+ +----+----+
| 3 | 4 |   | 7 | 8 |   | 11 | 12 | | 15 | 16 |
+---+---+   +---+---+   +----+----+ +----+----+
dim=1 (2x4x2):
Depth 0:        Depth 1:
+---+---+       +---+---+
| 1 | 2 |       | 5 | 6 |
+---+---+       +---+---+
| 3 | 4 |       | 7 | 8 |
+---+---+       +---+---+
| 9 | 10 |      | 13| 14|
+---+---+       +---+---+
| 11| 12 |      | 15| 16|
+---+---+       +---+---+
dim=2 (2x2x4):
Depth 0:        Depth 1:
+---+---+---+---+   +---+---+---+---+
| 1 | 2 | 9 | 10|   | 5 | 6 | 13| 14|
+---+---+---+---+   +---+---+---+---+
| 3 | 4 | 11| 12|   | 7 | 8 | 15| 16|
+---+---+---+---+   +---+---+---+---+

当然可以!让我们通过具体的例子来形象地解释不同维度上的拼接。

定义张量

首先,定义三个张量 x, y, z,它们分别具有如下形状:

  • x 的形状是 [2, 1, 3]
  • y 的形状是 [2, 3, 3]
  • z 的形状是 [2, 2, 3]
import torch

x = torch.tensor([[[0, 0, 0]], [[0, 0, 0]]])
y = torch.tensor([
    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
])
z = torch.tensor([
    [[0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0]]
])

(1) 在 dim=0 上拼接

dim=0 上拼接,相当于增加“深度”或“批次”的数量。每个张量的“深度”都会堆叠起来。

w_dim0 = torch.cat([x, y, z], dim=0)
print(w_dim0.shape)

形象解释

x:
[
 [[0, 0, 0]],  # 第一层深度
 [[0, 0, 0]]   # 第二层深度
]

y:
[
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度
]

z:
[
 [[0, 0, 0], [0, 0, 0]],  # 第一层深度
 [[0, 0, 0], [0, 0, 0]]   # 第二层深度
]

拼接结果 w_dim0:
[
 [[0, 0, 0]],  # x 第一层深度
 [[0, 0, 0]],  # x 第二层深度
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # y 第一层深度
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # y 第二层深度
 [[0, 0, 0], [0, 0, 0]],  # z 第一层深度
 [[0, 0, 0], [0, 0, 0]]   # z 第二层深度
]

形状:[6, 3, 3]

(2)dim=1 上拼接

dim=1 上拼接,相当于增加每个“深度”层中的行数。每个深度层的行数会拼接起来。

w_dim1 = torch.cat([x, y, z], dim=1)
print(w_dim1.shape)

形象解释

x:
[
 [[0, 0, 0]],  # 第一层深度的第一行
 [[0, 0, 0]]   # 第二层深度的第一行
]

y:
[
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度的三行
 [[0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度的三行
]

z:
[
 [[0, 0, 0], [0, 0, 0]],  # 第一层深度的两行
 [[0, 0, 0], [0, 0, 0]]   # 第二层深度的两行
]

拼接结果 w_dim1:
[
 [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度的六行
 [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度的六行
]

形状:[2, 6, 3]

当然可以!为了展示如何在 dim=2(第三个维度)上拼接张量,我们需要确保这些张量在前两个维度上的大小是相同的,而在第三个维度上的大小可以不同。

假设我们定义三个张量 a, b, c,它们分别具有如下形状:

  • a 的形状是 [2, 2, 2]
  • b 的形状是 [2, 2, 3]
  • c 的形状是 [2, 2, 1]
import torch

a = torch.tensor([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]]
])

b = torch.tensor([
    [[9, 10, 11], [12, 13, 14]],
    [[15, 16, 17], [18, 19, 20]]
])

c = torch.tensor([
    [[21], [22]],
    [[23], [24]]
])

(3)在 dim=2 上拼接

dim=2 上拼接,相当于增加每行中的元素数量(列数)。每个深度层中的列数会拼接起来:

w_dim2 = torch.cat([a, b, c], dim=2)
print(w_dim2)
print(w_dim2.shape)

形象解释

a:
[
 [[1, 2],      [3, 4]],     # 第一层深度的两行两列
 [[5, 6],      [7, 8]]      # 第二层深度的两行两列
]

b:
[
 [[9, 10, 11], [12, 13, 14]], # 第一层深度的两行三列
 [[15, 16, 17], [18, 19, 20]] # 第二层深度的两行三列
]

c:
[
 [[21],        [22]],       # 第一层深度的两行一列
 [[23],        [24]]        # 第二层深度的两行一列
]

拼接结果 w_dim2:
[
 [[1, 2, 9, 10, 11, 21], [3, 4, 12, 13, 14, 22]],       # 第一层深度的两行六列
 [[5, 6, 15, 16, 17, 23], [7, 8, 18, 19, 20, 24]]       # 第二层深度的两行六列
]

w_dim2 的形状为:[2, 2, 6]

通过在 dim=2 上拼接,结果张量 w_dim2 的第三个维度是各个张量第三个维度的和:2 + 3 + 1 = 6

# 代码输出:
# tensor([[[ 1,  2,  9, 10, 11, 21],
#          [ 3,  4, 12, 13, 14, 22]],
# 
#         [[ 5,  6, 15, 16, 17, 23],
#          [ 7,  8, 18, 19, 20, 24]]])
# 
# 形状: torch.Size([2, 2, 6])

希望这个例子能帮助你更好地理解如何在 dim=2 上拼接张量。
非常好的问题!让我们用书架的比喻来解释这个例子,这将有助于更直观地理解张量的维度。

在这个比喻中:

  • dim=0(第一个维度)代表书架的数量
  • dim=1(第二个维度)代表每个书架的层板数
  • dim=2(第三个维度)代表每个层板可以放置的书本数量(即层板的宽度)

让我们用这个比喻来解释 a, b, 和 c 这三个张量:

  1. 张量 a [2, 2, 2]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放2本书
  2. 张量 b [2, 2, 3]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放3本书
  3. 张量 c [2, 2, 1]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放1本书

当我们在 dim=2 上拼接这些张量时,相当于我们在不改变书架数量和层板数量的情况下,将每个层板变宽,使其可以容纳更多的书。

拼接后的结果 w_dim2 [2, 2, 6]:

  • 仍然是2个书架(dim=0 没变)
  • 每个书架仍然有2层层板(dim=1 没变)
  • 但是现在每个层板可以放6本书了(dim=2 变成了 2+3+1=6)

形象地说:

原来的书架 a:    原来的书架 b:    原来的书架 c:
[□□]            [□□□]           [□]
[□□]            [□□□]           [□]

[□□]            [□□□]           [□]
[□□]            [□□□]           [□]

拼接后的新书架 w_dim2:
[□□□□□□]  (2+3+1 = 6本书)
[□□□□□□]

[□□□□□□]
[□□□□□□]

每个 □ 代表一本书(或者说张量中的一个元素)。

这个比喻展示了我们如何在不增加书架数量(dim=0)或层板数量(dim=1)的情况下,通过拼接来增加每个层板可以放置的书本数量(dim=2)。这就是在 dim=2 上进行张量拼接的直观理解。

  • 12
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值