torch.nn.functional.unfold 用法解读

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

参数

  • kernel_size,滑动窗口的大小
  • dilation,和空洞卷积中的dilation一样,控制滑动窗口取值的间隔
  • padding,控制输入张量四周padding的大小
  • stride,控制滑动窗口滑动的步长

输入输出维度

  • 输入维度为 (N, C, *),其中N代表batch_size,C为通道数,*代表任意的空间维度。 
  • 输出维度为 \((N,C\times \prod (kernel\_size),L)\),其中 \(\prod (kernel\_size)\) 表示滑动窗口的大小,\(L\) 表示滑动窗口在输入张量中可以滑动的所有位置的个数。\(L\) 的大小计算如下

这里 \(L\) 的计算方式和卷积中输出大小的计算是一模一样的

这个函数的作用就是通过一个滑动窗口遍历输入张量,在这个过程中,取出输入张量中滑动窗口对应位置的值,组成输出向量。

这个过程有点类似于卷积,卷积过程中滑动窗口就是卷积核,卷积核滑动到某个位置,取出输入张量中对应位置的值与该卷积核对应位置的值分别相乘最后相加得到单个值的输出。而unfold是滑动到某一位置,直接取出输入张量中对应位置的值,然后flatten成一个一维向量,放到输出张量的对应位置处。

比如下图中的一个输入张量,batch_size=1,shape=(1, C, W, H),我们想沿WxH维度取出一个个局部特征,就可以通过unfold来实现。

示例

如下代码所示,输入x是一个维度为(2, 5, 3, 4)的张量,其中2是batch_size,5是通道数,3,4代表spatial维度的大小,对应上面输入维度中的 *

kernel大小为(2, 3),unfold做的就是在spatial维度滑动该kernel对应的窗口,取出对应的值,注意这里滑动窗口时取窗口内所有通道的值。上面的 \(\prod (kernel\_size)=2\times 3=6\),窗口内共有2x3x5=30个值,对应输出维度中的 \(C\times \prod (kernel\_size)\)。然后把窗口内对应的block展平成一个一维向量,放到输出中。接下来因为默认stride=1,dilation=1,padding=0,因此(2, 3)大小的kernel在维度(3, 4)中两个方向都只能滑动一次,因此滑动窗口只有4个位置,对应上面的 \(L=4\)。

x = torch.randn(2, 5, 3, 4)
output = F.unfold(x, kernel_size=(2, 3))
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
print(output.size())
# torch.Size([2, 30, 4])

参考

Unfold — PyTorch 1.12 documentation

PyTorch中torch.nn.functional.unfold函数使用详解_咆哮的阿杰的博客-CSDN博客_pytorch unfold

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值