张量拼接:
在PyTorch中,张量拼接主要有两种方法:torch.cat()
和 torch.stack()
。这两种方法在功能上有所不同,但都用于将多个张量合并为一个张量。
1. torch.cat()
torch.cat(tensors, dim=0)
方法用于在给定维度上连接张量序列。这意味着所有被拼接的张量的形状(除了拼接的维度)必须相同。
-
参数:
tensors
:一个张量序列。dim
:要拼接的维度。
-
示例:
# 沿着行(维度0)拼接两个张量 x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6], [7, 8]]) result = torch.cat([x, y], dim=0) print(result) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]])
2.
torch.stack()
torch.stack(tensors, dim=0)
方法用于在新创建的维度上堆叠张量序列。这意味着所有被堆叠的张量必须具有完全相同的形状,torch.stack
添加的新维度将成为张量的一部分。 -
参数:
tensors
:一个张量序列。dim
:新维度插入的位置。
-
示例:
# 在新创建的维度上堆叠两个张量 x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) result = torch.stack([x, y], dim=0) print(result) # 输出: # tensor([[1, 2], # [3, 4]])
在
torch.cat
中,张量是沿着已存在的某个维度进行连接的,而torch.stack
则是在张量序列之间添加一个新的维度进行堆叠,因此堆叠后的张量维度会增加。张量切分
在 PyTorch 中,张量切分主要通过两种方法实现:torch.chunk()
和 torch.split()
。这两种方法都将一个张量分割成多个较小的张量,但它们在如何定义分割方式上有所不同。
1. torch.chunk()
torch.chunk(input, chunks, dim=0)
方法将张量分割成指定数量的块。如果原始张量在指定维度上不能均等分割成目标数量的块,则最后一个块会小于其他块。
-
参数:
input
:要分割的张量。chunks
:希望得到的块的数量。dim
:要分割的维度。
-
示例:
x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) chunks = torch.chunk(x, 2, dim=0) for i, chunk in enumerate(chunks): print(f"Chunk {i}:\n {chunk}") # 输出: # Chunk 0: # tensor([[1, 2], # [3, 4]]) # Chunk 1: # tensor([[5, 6], # [7, 8]])
2.
torch.split()
torch.split(tensor, split_size_or_sections, dim=0)
方法根据给定的大小分割张量。split_size_or_sections
可以是单个整数,指定每个分割块的大小,或者是一个整数列表,指定每个块的具体大小。 -
参数:
tensor
:要分割的张量。split_size_or_sections
:单个整数时,每个块的大小;整数列表时,每个块的具体大小。dim
:要分割的维度。
-
示例(使用单个大小值):
x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) splits = torch.split(x, 2, dim=0) for i, split in enumerate(splits): print(f"Split {i}:\n {split}") # 输出: # Split 0: # tensor([[1, 2], # [3, 4]]) # Split 1: # tensor([[5, 6], # [7, 8]])
示例(使用大小列表):
x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) splits = torch.split(x, [2, 1, 2], dim=0) for i, split in enumerate(splits): print(f"Split {i}:\n {split}") # 输出: # Split 0: # tensor([[1, 2], # [3, 4]]) # Split 1: # tensor([[5, 6]]) # Split 2: # tensor([[7, 8], # [9, 10]])
选择
torch.chunk()
还是torch.split()
取决于你的具体需求:如果你需要将张量分割成大小大致相等的几部分,使用torch.chunk()
;如果你需要精确控制每个分割块的大小,使用torch.split()
。
张量索引
在PyTorch中,张量索引允许你选择、修改、提取数据子集的功能非常强大且灵活。下面详细介绍张量索引的主要方法:
1. 基本索引和切片
与Python列表和NumPy数组相似,PyTorch张量也支持基本的索引和切片操作。
- 示例:
x = torch.arange(10) # 创建一个包含0到9的一维张量 print(x[2:5]) # 从位置2到4的元素 print(x[:5]) # 开始到位置4的元素 print(x[5:]) # 位置5到结束的元素 print(x[-1]) # 最后一个元素
2. 高级索引
高级索引允许你使用张量来选择数据。
-
整数数组索引:可以使用整数数组(列表或张量)进行索引,选择不连续的张量元素。
y = torch.arange(12).view(3, 4) # 创建一个形状为3x4的二维张量 print(y[[0, 2], [2, 3]]) # 选择(0,2)和(2,3)位置上的元素
布尔索引:使用布尔表达式可以选择符合特定条件的元素。
print(y[y > 5]) # 选择所有大于5的元素
3. 选择操作(Fancy Indexing)
Fancy Indexing允许你使用整数列表进行索引。
- 示例:
print(y[:, [1, 3]]) # 选择所有行的第1和第3列
4. 掩码操作
掩码操作允许你根据条件选择元素,这通常用于赋值或者提取子集。
mask = y > 5 # 创建一个布尔掩码
print(y[mask]) # 使用布尔掩码选择元素
5. 维度变换
PyTorch提供了许多操作来改变张量的形状,这虽然不严格是索引,但在处理张量时非常有用。
view
:返回一个新的张量,具有相同的数据但大小不同。
z = torch.arange(10)
print(z.view(2, 5)) # 重塑为2x5的张量
unsqueeze
:在指定位置添加一个维度。
print(z.unsqueeze(0)) # 在第0维添加一个维度
squeeze
:移除所有长度为1的维度。
print(z.unsqueeze(0).squeeze(0)) # 移除第0维的单维度
6. 转置操作
转置可以改变张量的形状和维度顺序。
-
t
:对二维张量进行转置。a = torch.arange(6).view(2, 3) print(a.t())
-
permute
:对多维张量进行维度交换。b = torch.arange(24).view(2, 3, 4) print(b.permute(2, 0, 1)) # 交换维度