pytorch中的维度变换函数汇总

概述

PyTorch 提供了多种函数来执行张量的维度变换,这些操作对于深度学习模型的构建和数据预处理非常重要。下面是一些常用的与维度变换相关的函数:

1. torch.view

  • 用途:改变张量的形状而不改变其数据。
  • 说明:要求新形状的元素总数与原始张量相同。

2. torch.reshape

  • 用途:改变张量的形状而不改变其数据。
  • 说明:与 torch.view 类似,但 torch.reshape 可以处理不连续的张量。

3. torch.transpose

  • 用途:交换张量的两个维度。
  • 说明:常用于交换矩阵的行和列。

4. torch.permute

  • 用途:更一般地重新排列张量的维度。
  • 说明:可以一次性交换多个维度。

5. torch.squeeze

  • 用途:移除张量中所有维度为1的维度。
  • 说明:可指定移除特定的单维度。

6. torch.unsqueeze

  • 用途:在指定位置增加一个维度为1的维度。
  • 说明:常用于增加批量大小的维度。

7. torch.flatten

  • 用途:将张量展平为一维。
  • 说明:常用于连接卷积层和全连接层。

8. torch.chunk

  • 用途:将张量分割成特定数量的块。
  • 说明:沿指定的维度进行分割。

9. torch.split

  • 用途:将张量分割成特定大小的块。
  • 说明:可以指定每个块的大小,沿指定的维度进行分割。

10. torch.cat

  • 用途:将一系列张量按指定的维度连接。
  • 说明:所有张量的除了连接维度外,其他维度必须相同。

11. torch.stack

  • 用途:将一系列张量沿新的维度堆叠。
  • 说明:所有张量的形状必须完全相同。

这些函数提供了强大的工具来操作和变换张量的形状,是深度学习模型设计和实现中不可或缺的部分。在实际应用中,合理使用这些函数可以使数据处理和模型构建更加高效和灵活。

torch.view()详解

torch.view() 是 PyTorch 中用于改变张量形状的一个非常重要且常用的函数。它允许用户在保持张量数据不变的前提下,重新定义张量的形状。这个操作非常有用,尤其是在需要调整网络层输入输出张量形状的场景中。

基本用法

tensor.view(*shape)
  • tensor 是需要被重新形状的原始张量。
  • shape 是一个整数序列,定义了期望的输出张量形状。形状参数可以是多个整数,也可以是一个元组或列表。

注意事项

  1. 总元素个数不变:使用 view() 时,新形状的总元素个数必须与原始张量中的总元素个数相同。换句话说,只能改变形状,不能改变数据的总量。
  2. 连续性view() 要求原始张量在内存中是连续的。如果原始张量不是连续的,可能需要先调用 .contiguous() 方法来使张量连续。张量的连续性可以通过 .is_contiguous() 方法检查。

例子

基本示例
import torch

# 创建一个1x2x3的张量
x = torch.randn(1, 2, 3)
print(x)
print(x.shape)

# 使用view改变张量形状
y = x.view(2, 3)
print(y)
print(y.shape)

# 改变成3x2形状
z = x.view(3, 2)
print(z)
print(z.shape)

这个例子展示了如何将一个形状为 (1, 2, 3) 的张量重新形状为 (2, 3)(3, 2)

使用-1自动推断维度

在调用 view() 时,你可以在 shape 参数中使用 -1。PyTorch 将自动计算这一维度的大小。

x = torch.randn(4, 4)
print(x.shape)

# 使用-1自动推断维度
y = x.view(-1, 8)
print(y.shape)

z = x.view(8, -1)
print(z.shape)

在这个例子中,x.view(-1, 8) 表示我们希望得到一个第二维度为 8 的张量,PyTorch 自动计算第一维度应该是 2。同理,x.view(8, -1) 表示第一维度为 8,PyTorch 自动计算第二维度为 2。

总结

torch.view() 是一个非常有用的函数,可以灵活地改变张量的形状而不改变其数据。在实际应用中,正确地使用 view() 可以帮助我们更高效地设计和实现深度学习模型。记住在使用 view() 之前确保张量是连续的,或者在需要时使用 .contiguous() 来使其连续。

torch.reshape()详解

torch.reshape() 函数在 PyTorch 中用于改变张量的形状。与 torch.view() 类似,torch.reshape() 允许张量以不同的形状呈现,而不改变其底层数据。不过,与 torch.view() 的要求张量在内存中连续不同,torch.reshape() 可以作用于不连续的张量,并且如果可能,它会返回一个视图(view),否则会返回一个拷贝。

基本用法

torch.reshape(input, shape)
  • input 是需要被重新形状的原始张量。
  • shape 是定义新张量形状的整数或者由整数组成的元组或列表。

注意事项

  • 总元素个数不变:与 view() 类似,使用 reshape() 时,新形状的总元素个数必须与原始张量中的总元素个数相同。
  • 可能不连续:如果原始张量不连续,reshape() 可以处理,并且可能返回一个数据的拷贝,而不是视图。

例子

基本示例
import torch

# 创建一个1x2x3的张量
x = torch.randn(1, 2, 3)
print(x)
print(x.shape)

# 使用reshape改变张量形状
y = torch.reshape(x, (2, 3))
print(y)
print(y.shape)

# 改变成3x2形状
z = torch.reshape(x, (3, 2))
print(z)
print(z.shape)

在这个例子中,我们创建了一个形状为 (1, 2, 3) 的张量,并使用 torch.reshape() 将其形状改变为 (2, 3)(3, 2)

使用-1自动推断维度

view() 类似,reshape() 也支持使用 -1shape 参数中,PyTorch 将自动计算这一维度的大小。

x = torch.randn(4, 4)
print(x.shape)

# 使用-1自动推断维度
y = torch.reshape(x, (-1, 8))
print(y.shape)

z = torch.reshape(x, (8, -1))
print(z.shape)

在这个例子中,torch.reshape(x, (-1, 8)) 表示我们希望得到一个第二维度为 8 的张量,PyTorch 自动计算第一维度应该是 2。同理,torch.reshape(x, (8, -1)) 表示第一维度为 8,PyTorch 自动计算第二维度为 2。

总结

torch.reshape() 是一个强大的函数,它提供了灵活的张量形状改变能力。在实际应用中,reshape()view() 都非常有用,但是如果你不确定张量是否连续,或者你需要确保得到的结果不依赖于输入数据的内存布局,那么使用 reshape() 会是一个更安全的选择。

reshape就能完全替代view了么?

在许多情况下,torch.reshape()torch.view() 都可以用来改变张量的形状,但它们之间存在一些差异,这些差异决定了在特定情况下哪一个是更好的选择。理解这些差异有助于在实际应用中做出更合适的决策。

torch.view()

  • 连续性要求torch.view() 要求原始数据在内存中是连续的。如果原始数据不是连续的,你需要先调用 .contiguous() 方法来使数据连续,然后才能使用 view()
  • 性能优势:由于 view() 操作依赖于数据的连续性,它通常能够避免复制数据,直接返回原始数据的一个新视图(view)。这意味着 view() 在没有必要复制数据时,通常性能更好。

torch.reshape()

  • 灵活性torch.reshape() 可以处理不连续的数据,如果需要,它会自动处理数据的复制。这提供了更高的灵活性,因为你不需要担心数据是否连续。
  • 通用性:由于 reshape() 不要求数据连续,你可以在更多的场景下使用它,而不需要担心是否需要先调用 .contiguous()

使用建议

  • 当你确定数据是连续的,或者对性能有极致要求时,使用 torch.view() 可能是更好的选择。这是因为 view() 不会引入数据复制的开销,可以保持操作的高效性。
  • 在不确定数据连续性,或者希望代码具有更好的通用性和简洁性时,使用 torch.reshape() 是更安全、更方便的选择。它自动处理数据连续性的问题,使得代码更容易理解和维护。

结论

虽然在很多情况下你可以选择任意一个函数而不影响结果,理解它们的差异和适用场景可以帮助你更好地优化代码性能和可读性。在实践中,如果你的操作不频繁,或者对性能的影响不是特别关键,那么使用 torch.reshape() 会是一个更加方便和安全的选择,因为它减少了必须考虑的细节数量。然而,如果你在性能敏感的应用中工作,了解并正确使用 view()reshape() 将有助于你更有效地利用资源。

torch.transpose()详解

torch.transpose() 函数在 PyTorch 中用于交换张量的两个维度。这个函数非常有用,特别是当你需要改变多维数组的数据布局时。下面是 torch.transpose() 的基本用法和一些示例。

基本用法

函数原型如下:

torch.transpose(input, dim0, dim1)
  • input:要进行转置操作的输入张量。
  • dim0:要交换的第一个维度。
  • dim1:要交换的第二个维度。

torch.transpose() 返回一个新的张量,它是输入张量在指定的两个维度上交换后的结果。这个操作对于处理矩阵转置(2D张量)非常直观,但也可以用于更高维度的张量。

示例

矩阵转置(2D张量)

假设我们有一个2x3的矩阵,我们想要转置它:

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 转置矩阵
x_transposed = torch.transpose(x, 0, 1)
print(x_transposed)

输出将是:

tensor([[1, 4],
        [2, 5],
        [3, 6]])

在这个例子中,dim0dim1 分别是0和1,表示我们交换了矩阵的行和列。

更高维度的张量

对于一个3维的张量,我们可以选择不同的维度进行交换。例如,如果我们有一个形状为 (2, 3, 4) 的张量,我们可以选择交换第一个和第三个维度:

x = torch.randn(2, 3, 4)
x_transposed = torch.transpose(x, 0, 2)
print(x_transposed.shape)

输出将是:

torch.Size([4, 3, 2])

在这个例子中,通过交换第一个和第三个维度,张量的形状从 (2, 3, 4) 变成了 (4, 3, 2)

注意事项

  • torch.transpose() 对于2D张量(矩阵)非常直观,它相当于矩阵的转置操作。
  • 在更高维度的张量中使用 torch.transpose() 时,需要明确你想要交换的维度。这可以用于在不同操作中重新排列数据的布局。
  • torch.transpose() 返回的是原始数据的一个新视图,如果可能的话。这意味着修改返回的张量也会影响原始张量,除非涉及到必要的数据复制。

torch.transpose() 是处理张量形状和布局的强大工具,理解它的工作原理和如何使用它可以帮助你更有效地处理数据和构建模型。

torch.permute()详解

torch.permute() 是 PyTorch 中的一个非常有用的函数,它允许你重新排列多维张量的维度。与 torch.transpose() 相比,后者只能交换两个维度,torch.permute() 提供了更高的灵活性,允许一次性重新排列多个维度。这在处理具有多个维度的数据时特别有用,例如在图像处理(通常是4D张量:批次大小、通道数、高度、宽度)或其他多维数据操作中。

基本用法

函数原型如下:

torch.permute(*dims)
  • *dims:一个维度的序列,指定了输出张量的维度排列顺序。

这个函数返回一个新的张量,它的维度按照给定的顺序重新排列。

示例

3D 张量

假设我们有一个形状为 (2, 3, 4) 的3D张量,我们想要重新排列它的维度顺序。

import torch

x = torch.randn(2, 3, 4)

# 将维度从 (2, 3, 4) 重新排列为 (3, 2, 4)
x_permuted = x.permute(1, 0, 2)
print(x_permuted.shape)

输出将是:

torch.Size([3, 2, 4])

在这个例子中,我们将第二个维度(索引为1)移动到了第一个位置,第一个维度(索引为0)移动到了第二个位置,而第三个维度保持不变。这改变了张量的形状,但没有改变其内部数据的总体布局。

4D 张量(例如图像数据)

在处理图像数据时,经常需要在不同的维度之间进行转换。假设我们有一个批次的图像数据,其形状为 (batch_size, channels, height, width),我们想要将其转换为 (batch_size, height, width, channels) 的形式,以便与某些特定的图像处理库一起使用。

x = torch.randn(10, 3, 28, 28)  # 假设有10张3通道的28x28像素的图像

# 从 (批次大小, 通道数, 高度, 宽度) 到 (批次大小, 高度, 宽度, 通道数)
x_permuted = x.permute(0, 2, 3, 1)
print(x_permuted.shape)

输出将是:

torch.Size([10, 28, 28, 3])

在这个例子中,我们通过 permute() 将通道数从第二个位置移动到了最后一个位置,同时保持批次大小不变,并适当地调整高度和宽度的位置。

注意事项

  • torch.permute() 返回的是原始数据的一个新视图,如果可能的话。这意味着,修改返回的张量也会影响原始张量,除非涉及到必要的数据复制。
  • 使用 permute() 时需要特别注意新的维度顺序,错误的顺序可能会导致数据被错误地解释。

torch.permute() 是处理高维张量时非常有用的工具,使得在复杂的数据转换和操作中具有很大的灵活性。理解并正确使用这个函数可以帮助你更有效地处理数据,特别是在需要对数据的维度进行复杂重排时。
是的,你可以完全用permute替代transpose!

torch.squeeze()详解

torch.squeeze() 是 PyTorch 中的一个函数,用于从张量中移除所有长度为1的维度。这个操作不会改变张量中的数据,只是改变了其形状,使得它在某些操作中更加方便使用。这个函数特别有用在你想要消除由于某些操作(比如 torch.unsqueeze())引入的单维度,或者在处理来自某些特定类型的数据时自动出现的单维度。

函数原型

torch.squeeze() 的基本函数原型如下:

torch.squeeze(input, dim=None, *, out=None) -> Tensor
  • input:输入的张量。
  • dim:可选参数,指定要压缩的维度。如果指定了维度,那么只有在该维度长度为1时,该维度才会被移除。如果没有指定 dim,那么将移除所有长度为1的维度。
  • out:可选参数,指定输出张量。大多数情况下不需要使用。

基本用法

移除所有单维度

如果你有一个张量,它的形状包含了一个或多个长度为1的维度,torch.squeeze() 可以移除所有这些单维度,简化张量的形状。

import torch

x = torch.zeros(1, 2, 3, 1, 4)
print(x.shape)  # 输出: torch.Size([1, 2, 3, 1, 4])

y = torch.squeeze(x)
print(y.shape)  # 输出: torch.Size([2, 3, 4])

在这个例子中,x 是一个形状为 [1, 2, 3, 1, 4] 的张量。使用 torch.squeeze(x) 后,所有长度为1的维度都被移除,得到一个形状为 [2, 3, 4] 的张量。

指定维度移除

你也可以指定一个特定的维度来移除。如果该维度的长度为1,则会被移除;如果长度不为1,则张量形状不变。

x = torch.zeros(1, 2, 1, 3)
print(x.shape)  # 输出: torch.Size([1, 2, 1, 3])

y = torch.squeeze(x, 0)
print(y.shape)  # 输出: torch.Size([2, 1, 3])

z = torch.squeeze(x, 2)
print(z.shape)  # 输出: torch.Size([1, 2, 3])

在这个例子中,x 的形状是 [1, 2, 1, 3]。当我们指定 dim=0 时,因为第0维长度为1,它被移除了,得到形状 [2, 1, 3]。当我们指定 dim=2 时,同样因为第2维长度为1,它也被移除了,得到形状 [1, 2, 3]

注意事项

  • torch.squeeze() 返回的是原始数据的一个视图,如果可能的话。这意味着,修改返回的张量也会影响原始张量,除非涉及到必要的数据复制。
  • 在使用 torch.squeeze() 时,特别注意不要无意中移除了不该移除的维度,尤其是当你的数据维度可能会变化时。

torch.squeeze() 是处理张量时非常有用的工具,可以帮助清理数据形状,使得后续的操作更加简洁明了。理解并正确使用这个函数可以帮助你更有效地处理数据。

torch.unsqueeze()详解

torch.unsqueeze() 是 PyTorch 中的一个重要函数,它的作用是在指定的维度上为输入张量增加一个维度,该维度的长度为1。这个操作常用于增加张量的维度,以满足特定操作的需要,比如在使用某些特定的层时需要额外的批处理维度,或者在进行张量拼接时需要张量维度完全一致。

函数原型

torch.unsqueeze() 的函数原型如下:

torch.unsqueeze(input, dim) -> Tensor
  • input:输入的张量。
  • dim:指定要插入新维度的位置。维度从0开始计数,也可以使用负数索引,表示从后向前数。

使用示例

添加新维度

假设有一个形状为 (3, 4) 的二维张量,我们想要在不同的位置添加一个新的维度:

import torch

# 创建一个形状为 [3, 4] 的张量
x = torch.randn(3, 4)
print("Original shape:", x.shape)

# 在第0维前面添加一个新的维度
y = torch.unsqueeze(x, 0)
print("Shape after unsqueeze at dim 0:", y.shape)

# 在第1维前面添加一个新的维度
z = torch.unsqueeze(x, 1)
print("Shape after unsqueeze at dim 1:", z.shape)

# 使用负数索引,在最后一个维度后面添加一个新的维度
w = torch.unsqueeze(x, -1)
print("Shape after unsqueeze at last dim:", w.shape)

输出示例:

Original shape: torch.Size([3, 4])
Shape after unsqueeze at dim 0: torch.Size([1, 3, 4])
Shape after unsqueeze at dim 1: torch.Size([3, 1, 4])
Shape after unsqueeze at last dim: torch.Size([3, 4, 1])

从上面的例子可以看出,torch.unsqueeze() 在不同的位置添加了新的维度,改变了张量的形状。

注意事项

  • 维度索引dim 参数支持负数索引,这在处理具有不确定维度数量的张量时非常有用。
  • 不改变数据torch.unsqueeze() 操作不会改变张量的数据,只是改变了张量的形状。
  • 返回新张量:此操作不会就地修改输入张量,而是返回一个新的张量。
  • 广播机制:增加的维度常用于满足张量操作的广播机制,使得不同形状的张量可以进行数学运算。

使用场景

torch.unsqueeze() 在实际应用中非常有用,比如:

  • 当你需要对数据进行批处理时,通常需要在数据的最前面添加一个批处理维度。
  • 在进行某些特定操作,如张量拼接(torch.cat)或广播时,需要输入张量具有相同的维度数。
  • 在使用某些特定的层或操作时,它们可能期望输入具有特定的维度(例如,卷积层通常期望输入具有四个维度:批处理大小、通道数、高度、宽度)。

通过灵活运用 torch.unsqueeze(),你可以轻松调整张量的形状,以满足各种操作和层的需求。

torch.flatten()详解

torch.flatten() 是 PyTorch 中的一个函数,用于将输入张量展平成一维。这个操作通常用于将多维数据准备用于全连接层或者在进行某些特定操作前简化数据结构。它可以让你指定展平操作的开始和结束维度,这提供了额外的灵活性,允许你展平张量的一部分,而不是整个张量。

函数原型

torch.flatten() 的基本函数原型如下:

torch.flatten(input, start_dim=0, end_dim=-1) -> Tensor
  • input:输入的张量。
  • start_dim:指定从哪个维度开始展平,默认为0,即从第一个维度开始。
  • end_dim:指定在哪个维度结束展平,默认为-1,即展平到最后一个维度。

使用示例

完全展平张量

当你需要将一个多维张量完全展平成一维张量时,可以使用 torch.flatten()

import torch

x = torch.randn(2, 3, 4)  # 创建一个形状为 [2, 3, 4] 的张量
print("Original shape:", x.shape)

flat_x = torch.flatten(x)
print("Flattened shape:", flat_x.shape)

输出示例:

Original shape: torch.Size([2, 3, 4])
Flattened shape: torch.Size([24])
局部展平张量

你也可以指定从哪个维度开始到哪个维度结束进行展平,这在你想保留部分结构时非常有用:

x = torch.randn(2, 3, 4, 5)  # 创建一个形状为 [2, 3, 4, 5] 的张量
print("Original shape:", x.shape)

# 从第1维开始,到最后一维结束进行展平
partial_flat_x = torch.flatten(x, start_dim=1)
print("Partially flattened shape:", partial_flat_x.shape)

输出示例:

Original shape: torch.Size([2, 3, 4, 5])
Partially flattened shape: torch.Size([2, 60])

在这个例子中,张量从第1维开始展平,保留了第0维的结构,这使得展平后的张量形状变为 [2, 60]

注意事项

  • 使用 torch.flatten() 时,原始张量的数据不会被修改,这个操作返回一个新的张量。
  • start_dimend_dim 参数提供了展平操作的灵活性,你可以根据需要选择保留张量的哪些维度。
  • 在进行模型设计时,torch.flatten() 经常用于卷积层输出和全连接层输入之间,以将多维特征图转换为一维特征向量。

通过合理使用 torch.flatten(),你可以轻松地在需要的时候调整张量的形状,这在深度学习模型的设计和实现中非常有用。

torch.chunk()详解

torch.chunk() 是 PyTorch 中的一个函数,用于将张量分割成特定数量的块。这个操作在处理数据或模型时非常有用,特别是当你需要将数据均匀分配到不同的批次或者并行处理单元时。

函数原型

torch.chunk() 的基本函数原型如下:

torch.chunk(input, chunks, dim=0) -> List of Tensors
  • input:输入的张量。
  • chunks:要分割成的块的数量。如果无法均匀分割,则最后一个块会小于其它块。
  • dim:沿着哪个维度进行分割,默认为0。

使用示例

分割张量

假设你有一个形状为 [10, 5] 的二维张量,你想要沿着第一个维度(dim=0)将它分割成5个块:

import torch

# 创建一个形状为 [10, 5] 的张量
x = torch.randn(10, 5)
print("Original shape:", x.shape)

# 将张量分割成5个块
chunks = torch.chunk(x, chunks=5, dim=0)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i} shape: {chunk.shape}")

输出示例:

Original shape: torch.Size([10, 5])
Chunk 0 shape: torch.Size([2, 5])
Chunk 1 shape: torch.Size([2, 5])
Chunk 2 shape: torch.Size([2, 5])
Chunk 3 shape: torch.Size([2, 5])
Chunk 4 shape: torch.Size([2, 5])

在这个例子中,原始张量被分割成5个形状为 [2, 5] 的块。因为原始张量沿着第一个维度的大小(10)能够被分割数量(5)整除,所以每个块的大小是相同的。

处理不能均匀分割的情况

如果张量在指定的维度上不能被均匀分割成指定数量的块,那么最后一个块的大小会小于其它块:

# 创建一个形状为 [10, 5] 的张量
x = torch.randn(10, 5)

# 尝试将张量分割成3个块
chunks = torch.chunk(x, chunks=3, dim=0)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i} shape: {chunk.shape}")

输出示例:

Chunk 0 shape: torch.Size([4, 5])
Chunk 1 shape: torch.Size([4, 5])
Chunk 2 shape: torch.Size([2, 5])

在这个例子中,因为原始张量无法均匀分割成3个块,所以前两个块的形状为 [4, 5],最后一个块的形状为 [2, 5]

注意事项

  • torch.chunk() 返回的是一个张量列表,每个元素是分割后的一个块。
  • 如果分割的块数 chunks 大于沿着分割维度的大小,那么超出的块将是空的张量。
  • 分割操作不会改变原始数据的内容,只是在指定维度上进行视图分割。

torch.chunk() 在数据预处理、批处理或模型并行处理时非常有用,它提供了一种灵活的方式来分割数据,以适应不同的处理需求。

torch.split()详解

torch.split() 函数是 PyTorch 中用于将张量分割成多个较小的张量的函数。你可以指定分割的大小或者分割成特定数量的块。

函数原型

torch.split() 的基本函数原型如下:

torch.split(tensor, split_size_or_sections, dim=0)
  • tensor: 要分割的张量。
  • split_size_or_sections: 如果是单个整数,它表示每个分割块的大小(除了最后一个块可能会更小)。如果是整数列表或元组,它表示每个块的具体大小。
  • dim: 要分割的维度,默认为0。

使用示例

使用单个整数分割

当你提供一个单一整数时,torch.split() 会尝试将张量分割成尽可能均匀的块,每个块的大小由该整数确定。

import torch

# 创建一个形状为 [10, 5] 的张量
x = torch.randn(10, 5)

# 沿着第一个维度分割成大小为3的块
chunks = torch.split(x, 3, dim=0)

# 输出每个块的形状
for chunk in chunks:
    print(chunk.shape)

这将输出三个形状为 [3, 5] 的张量和一个形状为 [1, 5] 的张量,因为 10 不能被 3 整除。

使用整数列表分割

你也可以通过提供一个整数列表来精确指定每个块的大小。

# 沿着第一个维度分割成指定大小的块
chunks = torch.split(x, [2, 4, 4], dim=0)

# 输出每个块的形状
for chunk in chunks:
    print(chunk.shape)

这将输出一个形状为 [2, 5] 的张量,两个形状为 [4, 5] 的张量。这里列表 [2, 4, 4] 中的数字之和必须等于被分割维度的大小,在这个例子中是 10

注意事项

  • 如果 split_size_or_sections 是单个整数,且不能整除张量在指定维度的大小,则最后一个块将包含剩余的元素,其大小可能小于指定的块大小。
  • 如果 split_size_or_sections 是一个列表或元组,则其元素之和必须等于张量在指定维度的大小。
  • 分割操作不会在内存中创建新的数据副本,返回的块仅是原始张量数据的视图。

torch.split() 在处理批量数据或将数据分割成多个部分进行独立处理时非常有用。通过合适地选择分割大小和维度,你可以灵活地控制分割的结果。

torch.cat()详解

torch.cat() 是 PyTorch 中用于将多个张量沿着指定的维度拼接起来的函数。这个操作对于组合来自不同来源但具有相同形状(除了被拼接的维度)的数据非常有用。

函数原型

torch.cat() 的基本函数原型如下:

torch.cat(tensors, dim=0, *, out=None) -> Tensor
  • tensors:一个张量序列,即你想要拼接的所有张量的列表或元组。
  • dim:一个整数,指定要沿着哪个维度进行拼接。默认为0,即第一个维度。
  • out:可选参数,用于指定输出的目标张量。

使用示例

基本拼接

假设你有两个形状为 [3, 4] 的二维张量,你想要沿着第一个维度(dim=0)将它们拼接起来:

import torch

# 创建两个形状为 [3, 4] 的张量
x = torch.randn(3, 4)
y = torch.randn(3, 4)

# 沿着第一个维度拼接它们
result = torch.cat((x, y), dim=0)
print(result.shape)  # 输出: torch.Size([6, 4])

在这个例子中,两个张量被拼接成了一个形状为 [6, 4] 的新张量。

沿着不同的维度拼接

你也可以沿着第二个维度(dim=1)拼接同样的张量:

# 沿着第二个维度拼接它们
result = torch.cat((x, y), dim=1)
print(result.shape)  # 输出: torch.Size([3, 8])

在这个例子中,两个张量被拼接成了一个形状为 [3, 8] 的新张量。

注意事项
  • 所有被拼接的张量,除了指定的拼接维度之外,其他维度的大小必须相同。
  • 拼接操作不会分配新的内存来存储结果,而是在原有张量数据的基础上创建一个新的视图。
  • 如果你需要沿着多个维度进行拼接,你需要多次调用 torch.cat()

高级用法

拼接不同维度的张量

如果你想要拼接不同维度的张量,你需要先使用其他操作(如 unsqueezeview)来调整它们的维度,使得除了拼接维度外,其他维度的大小相同。

# 创建一个形状为 [3, 4] 的张量和一个形状为 [3, 1] 的张量
x = torch.randn(3, 4)
y = torch.randn(3, 1)

# 通过扩展 y 的第二个维度来匹配 x 的维度
y_expanded = y.expand(-1, 4)

# 现在可以沿着第一个维度拼接它们
result = torch.cat((x, y_expanded), dim=0)
print(result.shape)  # 输出: torch.Size([6, 4])

在这个例子中,我们首先将 y 的第二个维度扩展到与 x 相同的大小,然后再进行拼接。

torch.cat() 是一个非常实用的函数,尤其是在处理批量数据或者组合来自多个源的特征时。它的使用需要确保除了拼接维度之外,其他维度的大小是一致的,这样才能保证拼接操作的正确性。

torch.stack()详解

torch.stack() 是 PyTorch 中的一个函数,用于将一系列的张量沿着一个新的维度堆叠起来。这些张量的大小必须完全相同。

函数原型

torch.stack() 的基本函数原型如下:

torch.stack(tensors, dim=0, *, out=None) -> Tensor
  • tensors: 一个张量的序列。可以是列表或元组,包含要堆叠的张量。
  • dim: 要插入的新维度的索引。默认为0。
  • out: 可选参数,用于指定输出的目标张量。

使用示例

堆叠张量

如果你有一系列形状相同的张量,你可以使用 torch.stack() 将它们堆叠成一个新的张量。

import torch

# 创建三个形状为 [2, 3] 的张量
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)

# 将它们沿着一个新的维度堆叠起来
stacked = torch.stack((x, y, z), dim=0)

print(stacked.shape)  # 输出: torch.Size([3, 2, 3])

在这个例子中,我们创建了一个新的三维张量,其形状为 [3, 2, 3]。新的维度被添加在最前面(索引为0),并且包含了原始的三个张量。

改变堆叠维度

你可以通过改变 dim 参数来控制新维度的插入位置。

# 将它们沿着第二个维度堆叠起来
stacked = torch.stack((x, y, z), dim=1)

print(stacked.shape)  # 输出: torch.Size([2, 3, 3])

在这个例子中,新的维度被添加在第二个位置(索引为1),导致形状变为 [2, 3, 3]

注意事项

  • torch.stack() 要求所有输入张量的形状完全相同。如果形状不同,会抛出错误。
  • 堆叠操作会创建一个新的张量,这意味着它会在内存中分配新的空间来存储结果。
  • 新维度的大小等于输入张量的数量。
  • torch.stack()torch.cat() 是不同的操作。torch.cat() 用于沿着已经存在的维度连接张量,而 torch.stack() 是在新维度上堆叠张量。

torch.stack() 在需要保持张量独立性的同时组合它们时非常有用,例如在创建批次数据时,每个张量是一个独立的样本,你可能希望将它们堆叠成一个批次。

  • 9
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值