Pytorch中常见调整维度函数总结

  • flatten()

.flatten()函数用于将输入的多维数组(或张量)转换为一个一维数组(或张量)。在 PyTorch 中,可以使用 .flatten() 方法来将张量(可以是多维的)平铺为一个一维张量。例如,如果有一个二维张量 x,你可以使用 .flatten() 方法将其平铺为一个一维张量。示例代码如下:

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 将二维张量 x 平铺成一维张量
x_flattened = x.flatten()

print(x_flattened)

输出:

tensor([1, 2, 3, 4, 5, 6])
  • unsqueeze(dim)

增加维度,例如,.unsqueeze(2)的作用是在张量的第三维(从0开始计数)上增加一个维度。这通常用于在处理图像或音频数据时,需要将单通道数据转换为多通道数据。

例如,如果有一个形状为 (3, 4) 的张量,使用 unsqueeze(2) 后,它的形状将变为 (3, 4, 1)。这个操作可以用来扩展张量的维度,以便与其他张量进行运算或者满足网络模型的输入要求。

以下是一个示例代码

import torch

# 创建一个示例张量
tensor_2d = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

# 在第三维上增加一个维度
tensor_expanded = tensor_2d.unsqueeze(2)

# 打印结果
print(tensor_expanded.shape)  # 输出: torch.Size([3, 4, 1])
  •  squeeze()

在 PyTorch 中,.squeeze() 是一个用于删除维度大小为 1 的维度的方法。如果张量中存在维度大小为 1 的维度,.squeeze() 方法将会删除这些维度,从而得到一个形状更小的张量。

下面是一个例子,演示了如何使用 .squeeze() 方法:

import torch

# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)

# 使用 .squeeze() 删除维度大小为 1 的维度
x_squeezed = x.squeeze()

print(x.size())          # 输出: torch.Size([1, 3, 1, 4])
print(x_squeezed.size())  # 输出: torch.Size([3, 4])

squeeze也可删除指定维度,如将上述代码中改为x.squeeze(0)

import torch

# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)

# 使用 .squeeze() 删除维度大小为 1 的维度
x_squeezed = x.squeeze(0)

print(x.size())          # 输出: torch.Size([1, 3, 1, 4])
print(x_squeezed.size())  # 输出: torch.Size([3, 1, 4])
  • view()

在 PyTorch 中,.view 是一个用于改变张量形状的函数,它可以用来调整张量的维度和大小,但不会改变张量中元素的数量。

具体来说,.view 函数接受一个或多个参数作为新的形状,并返回一个具有新形状的张量,同时保持张量中的数据不变。这意味着,对于一个具有 12 个元素的张量,你可以将其视图调整为一个形状为 (3, 4) 的二维张量,或者是一个形状为 (2, 2, 3) 的三维张量,只要满足新形状的总元素数量等于原始张量的元素数量。

以下是一个简单示例代码:

import torch

# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5, 6])

# 使用.view改变张量形状
y = x.view(2, 3)  # 将张量视图改变为一个形状为 (2, 3) 的二维张量

print(y)
tensor([[1, 2, 3],
        [4, 5, 6]])

Transformer中的view应用:

x = x.view(*head_shape, self.heads, self.d_k)

*head_shape 表示 head_shape 是一个元组或列表,其中包含了一些维度信息,通过 * 号将其展开为单独的维度参数。

  • torch.flip() 

torch.flip(input, dims)是pytorch中的函数,用于对输入张量的指定维度进行翻转操作。具体来说,torch.flip(input, dims)函数会返回一个新的张量,其中的元素按照指定维度进行翻转。翻转后的元素顺序与原始张量在指定维度上的顺序相反。

input是输入张量,dims是一个整数或整数列表,表示要进行翻转的维度。如果dims是一个整数,则只对该维度进行翻转;如果dims是一个整数列表,则对列表中的所有维度进行翻转。

import torch
 
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
 
result1 = torch.flip(x, dims=0)
result2 = torch.flip(x, dims=[0,1])
 
print(result1)
print(result2)
 
"""
result1 = tensor([[7, 8, 9],
                  [4, 5, 6],
                  [1, 2, 3]])
result2 = tensor([[9, 8, 7],
                  [6, 5, 4],
                  [3, 2, 1]])
"""
  • torch.stack()

https://blog.csdn.net/flyingluohaipeng/article/details/125034358?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171394866216800184160170%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=171394866216800184160170&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_click~default-2-125034358-null-null.142^v100^pc_search_result_base4&utm_term=torch.stack&spm=1018.2226.3001.4187 沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

  • 17
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值