相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,flatten是Tensor类的一个重要方法,同时它也是一个torch模块中的一个函数,它们的语法如下所示。
Tensor.flatten(start_dim=0, end_dim=-1) → Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
input (Tensor) – the input tensor
start_dim (int) – the first dim to flatten
end_dim (int) – the last dim to flatten
flatten函数(或方法)用于将一个张量以特定方法展平, 如果传递了参数,则会将从start_dim到end_dim之间的维度展开。默认情况下,flatten将从第0维展平至最后1维。
flatten函数(或方法)可能返回原始张量、原始张量的视图(共享底层存储)或原始张量的副本:
- 如果没有维度被展平,则返回原始张量(同一个对象)。
- 如果输出张量可以视为等效地使用View展平,则返回视图(共享底层存储)。
- 如果输出张量不能视为等效地使用View展平,则返回数据副本。
可以查看View相关的文章,进行更加深入的了解,下面有三个例子分别说明这三种情况:
# 例1
import torch
input_tensor = torch.tensor([[1, 2], [3, 4]])
flattened_tensor = torch.flatten(input_tensor, start_dim=0, end_dim=0)
print(input_tensor)
print(flattened_tensor)
print(id(flattened_tensor) == id(input_tensor)) # 查看是否是同一个张量对象
print(flattened_tensor.storage().data_ptr() == input_tensor.storage().data_ptr()) # 查看是否共享底层存储
输出:
tensor([[1, 2],
[3, 4]])
tensor([[1, 2],
[3, 4]])
True
True
# 例2
import torch
input_tensor = torch.tensor([[1, 2], [3, 4]])
flattened_tensor = torch.flatten(input_tensor, start_dim=0, end_dim=1)
print(input_tensor)
print(flattened_tensor)
print(id(flattened_tensor) == id(input_tensor)) # 查看是否是同一个张量对象
print(flattened_tensor.storage().data_ptr() == input_tensor.storage().data_ptr()) # 查看是否共享底层存储
输出:
tensor([[1, 2],
[3, 4]])
tensor([1, 2, 3, 4])
False
True
# 例3
import torch
input_tensor = torch.tensor([[1, 2], [3, 4]]).transpose(0, 1)
flattened_tensor = torch.flatten(input_tensor, start_dim=0, end_dim=1)
print(input_tensor)
print(flattened_tensor)
print(id(flattened_tensor) == id(input_tensor)) # 查看是否是同一个张量对象
print(flattened_tensor.storage().data_ptr() == input_tensor.storage().data_ptr()) # 查看是否共享底层存储
输出:
tensor([[1, 3],
[2, 4]])
tensor([1, 3, 2, 4])
False
False