1. 张量形状的基本操作
- 查看张量形状:使用
tensor.shape
或tensor.size()
来查看张量的形状。 - Reshape:使用
tensor.view()
或tensor.reshape()
来改变张量的形状。需要注意的是,新的形状必须与原始形状的元素数量一致。
2. 张量形状的计算技巧
- 保持元素总数一致:在改变张量形状时,确保新形状的所有维度的乘积等于原形状的所有维度的乘积。例如,一个形状为
(2, 3)
的张量有 6 个元素,可以 reshape 为(3, 2)
或(6,)
,但不能 reshape 为(2, 2)
。 - 使用 -1 自动推断维度:在 reshape 时,可以使用
-1
来让 PyTorch 自动计算这一维度的大小。例如,tensor.view(-1, 3)
会根据总元素数和其他维度的大小推断出第一维度的大小。
3. 常见的形状变化操作
- 矩阵乘法:对两个矩阵进行乘法时,满足条件
A: (m, n)
和B: (n, p)
,结果矩阵形状为(m, p)
。 - 点积:两个向量的点积,如
A: (n,)
和B: (n,)
,结果为一个标量。 - 批处理:在深度学习中常见的操作是对多个样本进行批处理,通常会在张量的第一个维度上添加批次大小。例如,一个形状为
(batch_size, seq_length, embedding_dim)
的张量表示一个批次的嵌入序列。
4. 张量广播
广播是一种自动扩展较小张量形状以匹配较大张量形状的机制。在执行元素级操作时,广播机制会扩展较小的张量。
例如:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape (2, 3)
b = torch.tensor([1, 2, 3]) # shape (3,)
result = a + b # b 会自动广播为 shape (2, 3)
print(result) # tensor([[2, 4, 6], [5, 7, 9]])
5. 示例代码解释
理解张量形状变化需要结合具体的代码。以下是一个示例,展示了张量形状的变化:
import torch
# 初始张量
x = torch.randn(2, 3, 4) # shape (2, 3, 4)
# 1. 查看张量形状
print("初始张量形状:", x.shape)
# 2. 改变张量形状
x_reshaped = x.view(6, 4) # shape (6, 4)
print("Reshape 后的形状:", x_reshaped.shape)
# 3. 交换维度
x_transposed = x.transpose(0, 1) # shape (3, 2, 4)
print("转置后的形状:", x_transposed.shape)
# 4. 使用 -1 自动推断维度
x_reshaped_auto = x.view(-1, 4) # shape (6, 4)
print("使用 -1 自动推断维度后的形状:", x_reshaped_auto.shape)
# 5. 广播机制
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape (2, 3)
b = torch.tensor([1, 2, 3]) # shape (3,)
result = a + b # b 会自动广播为 shape (2, 3)
print("广播机制结果:", result)
6. 实践练习
通过实践练习和调试代码,你可以更好地理解张量形状的变化。尝试不同的形状变化操作,并使用 print
语句输出每一步的张量形状,观察变化。