nn.Unfold()函数用法

nn.Unfold 是 PyTorch 中的一个函数,它的作用是从多维输入数据中提取局部块(patches)。这些局部块可以看作是输入数据的窗口视图,通常用于实现跨步卷积(stride convolution)或其他形式的局部特征提取。

函数定义

nn.Unfold(kernel_size) 的基本参数是 kernel_size,它定义了每个局部块的大小。可选参数包括:

  • stride:步长,控制局部块在输入数据上滑动的间隔。默认值为 1。
  • padding:填充,控制局部块边缘的零填充数量。默认值为 0。

工作原理

nn.Unfold 按照指定的 kernel_sizestridepadding 从输入数据中提取局部块,并将这些块展平为一维向量,然后按顺序堆叠起来形成输出张量。

输出张量的形状

输出张量的形状为 [N, C * kernel_size[0] * kernel_size[1], L],其中:

  • N 是批次大小(batch size)。
  • C 是输入数据的通道数。
  • L 是局部块的数量。

具体例子

假设我们有一个输入数据 x,其形状为 [1, 1, 5, 5],即一个批次大小为 1,通道数为 1,空间维度为 5x5 的张量。我们想要使用 nn.Unfold 来提取大小为 3x3、步长为 1、填充为 1 的局部块。

  1. 填充:首先,我们在输入数据的边缘添加 1 像素的零填充,得到一个形状为 [1, 1, 7, 7] 的张量。

  2. 滑动窗口:然后,我们使用一个 3x3 的窗口在填充后的图像上滑动,步长为 1。由于填充和步长,窗口可以覆盖整个图像。

  3. 提取局部块:每个 3x3 的局部块被提取出来,展平为一个包含 9 个元素的向量(3x3 = 9)。

  4. 计算局部块数量:在这个例子中,由于步长为 1,我们可以得到 5x5 个局部块(7 - 3 + 1 = 5)。

  5. 输出张量形状:最终,nn.Unfold 的输出形状将是 [1, 1 * 3 * 3, 5 * 5],即 [1, 9, 25]

import torch
import torch.nn as nn

# 创建输入张量 x
x = torch.randn(1, 1, 5, 5)

# 使用 nn.Unfold 提取局部块
unfold = nn.Unfold(kernel_size=(3, 3), stride=1, padding=1)
x_unfolded = unfold(x)

# 输出 x_unfolded 的形状
print(x_unfolded.shape)  # 输出: torch.Size([1, 9, 25])

在这个例子中,nn.Unfold 帮助我们从输入张量中有效地提取了所有 3x3 的局部块,并将它们转换为一个三维张量,这个张量可以用于后续的网络层处理。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值