PyTorch中flatten() 函数的用法

一. 用法

Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。

二. 参数

      1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。

       2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。

三. 实例

 (1). 首先随机定义一个满足正态分布的(2,3,4)的数据x

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)

------------------------------------
tensor([[[ 0.1281,  1.6878,  0.2301, -0.0721],
         [ 1.2374, -0.6929,  1.1186,  0.4372],
         [ 0.5122,  1.4653, -0.1673,  0.7258]],

        [[ 0.2772, -1.9994, -1.2284,  0.2764],
         [-0.0451, -0.9195,  0.5749,  0.1942],
         [ 0.8539, -0.0434, -0.7313,  0.0234]]])
tensor([ 0.1281,  1.6878,  0.2301, -0.0721,  1.2374, -0.6929,  1.1186,  0.4372,
         0.5122,  1.4653, -0.1673,  0.7258,  0.2772, -1.9994, -1.2284,  0.2764,
        -0.0451, -0.9195,  0.5749,  0.1942,  0.8539, -0.0434, -0.7313,  0.0234])

此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。

 (2).

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(1)
print(x)

===========================================
tensor([[[-0.7137, -0.0859, -1.5284,  0.7284],
         [ 0.8425,  0.3606,  1.7639,  0.1848],
         [ 0.4040, -1.6575,  1.9134, -1.0787]],

        [[ 0.6981,  1.3494, -0.5817, -1.1824],
         [-0.4972,  0.4179,  2.1742, -0.2462],
         [ 0.2429, -1.9315, -0.3497,  0.7190]]])
tensor([[-0.7137, -0.0859, -1.5284,  0.7284,  0.8425,  0.3606,  1.7639,  0.1848,
          0.4040, -1.6575,  1.9134, -1.0787],
        [ 0.6981,  1.3494, -0.5817, -1.1824, -0.4972,  0.4179,  2.1742, -0.2462,
          0.2429, -1.9315, -0.3497,  0.7190]])

此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)

注意:start_dimend_dim参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim() 之间。

  • 7
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorchflatten函数用于将张量展平为一维。在引用的案例2,通过使用torch.flatten函数将一个3x3x3的张量展平为一维。可以看到,在start_dim=0和end_dim=1的情况下,张量的形状从(3, 3, 3)变为(9, 3)。在引用的案例3,通过使用torch.flatten函数将一个3x3x3的张量展平为一维。在start_dim=1和end_dim=2的情况下,张量的形状从(3, 3, 3)变为(3, 9)。最后,在引用的案例,如果将flatten函数改为reshape函数,将会得到不同的结果。具体来说,将图片张量reshape为形状为(1, 1, 1, -1)的张量,并通过这个张量传递给Tudui模型进行处理。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [PyTorch基础(15)-- torch.flatten()方法](https://blog.csdn.net/dongjinkun/article/details/121479361)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [深度学习(PyTorch)——flatten函数用法及其与reshape函数的区别](https://blog.csdn.net/qq_42233059/article/details/126663501)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值