`Flatten(start_dim=1, end_dim=-1)` 是PyTorch中的一个函数,用于将输入张量进行扁平化操作。它可以将多维的张量转换为一维张量,保持数据的顺序不变。
参数:
- `start_dim`(可选):指定开始扁平化的维度。默认值为 1,表示从第二个维度开始扁平化。注意,维度索引是从 0 开始的。
- `end_dim`(可选):指定结束扁平化的维度。默认值为 -1,表示扁平化到最后一个维度。
返回值:
- 返回一个新的张量,是输入张量扁平化后的结果。
下面是一个示例,说明如何使用 `Flatten()` 函数:
import torch
input = torch.tensor([[1, 2, 3],
[4, 5, 6]])
output = torch.flatten(input, start_dim=0, end_dim=1)
print(output)
tensor([1, 2, 3, 4, 5, 6])
在上面的示例中,输入张量 `input` 是一个 2D 张量,形状为 (2, 3)。使用 `torch.flatten()` 函数对 `input` 进行扁平化操作,将其转换为一维张量。由于没有指定 `start_dim` 和 `end_dim`,默认从第二个维度(即行维度)开始扁平化,并扁平化到最后一个维度(即列维度)。最终的输出张量 `output` 是一个一维张量,包含了原始张量中的所有元素,按照原始张量的顺序排列。
请注意,`Flatten()` 函数返回的是一个新的张量,原始张量保持不变。