-
flatten()
.flatten()函数用于将输入的多维数组(或张量)转换为一个一维数组(或张量)。在 PyTorch 中,可以使用 .flatten()
方法来将张量(可以是多维的)平铺为一个一维张量。例如,如果有一个二维张量 x
,你可以使用 .flatten()
方法将其平铺为一个一维张量。示例代码如下:
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 将二维张量 x 平铺成一维张量
x_flattened = x.flatten()
print(x_flattened)
输出:
tensor([1, 2, 3, 4, 5, 6])
-
unsqueeze(dim)
增加维度,例如,.unsqueeze(2)
的作用是在张量的第三维(从0开始计数)上增加一个维度。这通常用于在处理图像或音频数据时,需要将单通道数据转换为多通道数据。
例如,如果有一个形状为 (3, 4) 的张量,使用 unsqueeze(2)
后,它的形状将变为 (3, 4, 1)。这个操作可以用来扩展张量的维度,以便与其他张量进行运算或者满足网络模型的输入要求。
以下是一个示例代码
import torch
# 创建一个示例张量
tensor_2d = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 在第三维上增加一个维度
tensor_expanded = tensor_2d.unsqueeze(2)
# 打印结果
print(tensor_expanded.shape) # 输出: torch.Size([3, 4, 1])
-
squeeze()
在 PyTorch 中,.squeeze()
是一个用于删除维度大小为 1 的维度的方法。如果张量中存在维度大小为 1 的维度,.squeeze()
方法将会删除这些维度,从而得到一个形状更小的张量。
下面是一个例子,演示了如何使用 .squeeze()
方法:
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
# 使用 .squeeze() 删除维度大小为 1 的维度
x_squeezed = x.squeeze()
print(x.size()) # 输出: torch.Size([1, 3, 1, 4])
print(x_squeezed.size()) # 输出: torch.Size([3, 4])
squeeze也可删除指定维度,如将上述代码中改为x.squeeze(0)
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
# 使用 .squeeze() 删除维度大小为 1 的维度
x_squeezed = x.squeeze(0)
print(x.size()) # 输出: torch.Size([1, 3, 1, 4])
print(x_squeezed.size()) # 输出: torch.Size([3, 1, 4])
-
view()
在 PyTorch 中,.view
是一个用于改变张量形状的函数,它可以用来调整张量的维度和大小,但不会改变张量中元素的数量。
具体来说,.view
函数接受一个或多个参数作为新的形状,并返回一个具有新形状的张量,同时保持张量中的数据不变。这意味着,对于一个具有 12 个元素的张量,你可以将其视图调整为一个形状为 (3, 4) 的二维张量,或者是一个形状为 (2, 2, 3) 的三维张量,只要满足新形状的总元素数量等于原始张量的元素数量。
以下是一个简单示例代码:
import torch
# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 使用.view改变张量形状
y = x.view(2, 3) # 将张量视图改变为一个形状为 (2, 3) 的二维张量
print(y)
tensor([[1, 2, 3],
[4, 5, 6]])
Transformer中的view应用:
x = x.view(*head_shape, self.heads, self.d_k)
*head_shape
表示 head_shape 是一个元组或列表,其中包含了一些维度信息,通过 *
号将其展开为单独的维度参数。
-
torch.flip()
torch.flip(input, dims)是pytorch中的函数,用于对输入张量的指定维度进行翻转操作。具体来说,torch.flip(input, dims)函数会返回一个新的张量,其中的元素按照指定维度进行翻转。翻转后的元素顺序与原始张量在指定维度上的顺序相反。
input是输入张量,dims是一个整数或整数列表,表示要进行翻转的维度。如果dims是一个整数,则只对该维度进行翻转;如果dims是一个整数列表,则对列表中的所有维度进行翻转。
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result1 = torch.flip(x, dims=0)
result2 = torch.flip(x, dims=[0,1])
print(result1)
print(result2)
"""
result1 = tensor([[7, 8, 9],
[4, 5, 6],
[1, 2, 3]])
result2 = tensor([[9, 8, 7],
[6, 5, 4],
[3, 2, 1]])
"""
-
torch.stack()
https://blog.csdn.net/flyingluohaipeng/article/details/125034358?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171394866216800184160170%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=171394866216800184160170&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_click~default-2-125034358-null-null.142^v100^pc_search_result_base4&utm_term=torch.stack&spm=1018.2226.3001.4187 沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。