torch.flatten()
#展平一个连续范围的维度,输出类型为Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
# Parameters:input (Tensor) – 输入为Tensor
#start_dim (int) – 展平的开始维度
#end_dim (int) – 展平的最后维度
#example
#一个3x2x2的三维张量
>>> t = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]])
#当开始维度为0,最后维度为-1,展开为一维
>>> torch.flatten(t)
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
#当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
>>> torch.flatten(t, start_dim=1)
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
>>> torch.flatten(t, start_dim=1).size()
torch.Size([3, 4])
#下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
#前面的就会合并
>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10],
[11, 12]])
>>> torch.flatten(t, start_dim=0, end_dim=1).size()
torch.Size([6, 2])
上述转载自:https://blog.csdn.net/GhostintheCode/article/details/102530451
可以用x.view()函数实现上述功能
import torch
#展平一个连续范围的维度,输出类型为Tensor
t = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]])
t1 = torch.flatten(t, start_dim=1)
t2 = t.view([t.size()[0],-1])
#输出结果:
t = tensor([[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]]])
t1 = tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
t2 = tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])