torch.faltten()函数功能,当pytorch版本不支持faltten函数时,如何在不升级pytorch版本情况下实现faltten功能

torch.flatten()

#展平一个连续范围的维度,输出类型为Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
# Parameters:input (Tensor) – 输入为Tensor
#start_dim (int) – 展平的开始维度
#end_dim (int) – 展平的最后维度
#example
#一个3x2x2的三维张量
>>> t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]],
                  [[9, 10],
                       [11, 12]]])
#当开始维度为0,最后维度为-1,展开为一维
>>> torch.flatten(t)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
#当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
>>> torch.flatten(t, start_dim=1)
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
>>> torch.flatten(t, start_dim=1).size()
torch.Size([3, 4])
#下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
#前面的就会合并
>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
>>> torch.flatten(t, start_dim=0, end_dim=1).size()
torch.Size([6, 2])

上述转载自:https://blog.csdn.net/GhostintheCode/article/details/102530451

可以用x.view()函数实现上述功能

import torch
#展平一个连续范围的维度,输出类型为Tensor
t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]],
                  [[9, 10],
                   [11, 12]]])
t1 = torch.flatten(t, start_dim=1)
t2 = t.view([t.size()[0],-1])

#输出结果:
t = tensor([[[ 1,  2],
             [ 3,  4]],

            [[ 5,  6],
             [ 7,  8]],
  
            [[ 9, 10],
             [11, 12]]])

t1 = tensor([[ 1,  2,  3,  4],
             [ 5,  6,  7,  8],
             [ 9, 10, 11, 12]])

t2 = tensor([[ 1,  2,  3,  4],
             [ 5,  6,  7,  8],
             [ 9, 10, 11, 12]])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值