- input: 输入,类型为Tensor。
- start_dim: 推平的起始维度。
- end_dim: 推平的结束维度。
import torch
a = torch.ones(2,3,4,5)
b = torch.flatten(a,start_dim=0,end_dim=2)
# 从0维开始往后推,推到第2维。所以最后应该是:(2*3*4,5)
print(b.shape)
b = torch.flatten(a,end_dim=2)
# 默认为0
print(b.shape)
b = torch.flatten(a,start_dim=-1)
# 从最后一维往后退,不变
print(b.shape)
b = torch.flatten(a,end_dim=-1)
# 推到最后一维,展平
print(b.shape)
Result:
torch.Size([24, 5])
torch.Size([24, 5])
torch.Size([2, 3, 4, 5])
torch.Size([120])