pytorch nn.Unflatten 和 nn.Flatten模块介绍

nn.Flatten 和 nn.Unflatten 是 PyTorch 中用于调整张量形状的模块。它们提供了对多维张量的简单变换,常用于神经网络模型的层之间的数据调整。


1. nn.Flatten

功能:

  • 将输入张量展平为二维张量,通常用于将卷积层的输出展平成全连接层的输入。
  • 它会将张量的指定维度范围压缩为单个维度。

构造参数:

  • start_dim: 展平的起始维度(默认值为 1)。
  • end_dim: 展平的结束维度(默认值为 -1)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, channels, height, width]
x = torch.randn(4, 3, 32, 32)

# 展平操作
flatten = nn.Flatten(start_dim=1)  # 从维度1到最后展平
y = flatten(x)

print(y.shape)  # 输出: [4, 3072] (3*32*32 被展平)

适用场景:

  • 通常用于从卷积层(或其他多维特征)到全连接层的过渡。
  • 例如:[batch_size, channels, height, width] -> [batch_size, features]

2. nn.Unflatten

功能:

  • 将展平的张量还原为多维张量。
  • 它通过指定目标维度和形状信息,反向操作 nn.Flatten

构造参数:

  • dim: 需要展开的维度。
  • unflattened_size: 展开的形状(tuple 类型)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, features]
x = torch.randn(4, 3072)

# 还原操作
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))
y = unflatten(x)

print(y.shape)  # 输出: [4, 3, 32, 32]

适用场景:

  • 通常用于从全连接层(或展平特征)还原到卷积层或其他多维表示。
  • 例如:[batch_size, features] -> [batch_size, channels, height, width]

对比

特性nn.Flattennn.Unflatten
主要操作将多个维度压缩为一个维度将一个维度展开为多个维度
输入多维张量展平的张量
输出二维张量恢复为多维张量
常用场景用于连接卷积层和全连接层用于从展平的特征恢复到多维结构
参数控制指定展平的起始和结束维度范围指定需要展开的维度和目标形状

实际应用示例

结合使用 Flatten 和 Unflatten:

import torch
from torch import nn

# 初始化 Flatten 和 Unflatten
flatten = nn.Flatten(start_dim=1)
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))

# 模拟数据
x = torch.randn(4, 3, 32, 32)  # [batch_size, channels, height, width]

# 展平
flat_x = flatten(x)
print(flat_x.shape)  # 输出: [4, 3072]

# 恢复
unflat_x = unflatten(flat_x)
print(unflat_x.shape)  # 输出: [4, 3, 32, 32]

这两个模块通过简单的接口提供了灵活的形状调整功能,是构建神经网络过程中不可或缺的工具。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值