在卷积网络中,经常会需要用到卷积核滑动窗口的操作,如下图所示。这个操作在大多数的深度学习框架中,都被封装的很好,以至于我们并不需要显式地调用便可以实现卷积网络的这个操作。
但是,大部分深度学习框架也是提供了显式地进行滑动窗口操作的API的,在pytorch
中就是unfold
和fold
。接下来我们来探讨下这两个函数的使用。
在pytorch
中,和unfold
有关的有:torch.nn.Unfold
, torch.nn.functional.unfold
,tensor.unfold
,其中我们讨论的是第一个,也就是作为网络层的Unfold
。
torch.nn.Unfold
按照官方的说法,就是从一个批次的输入样本中,提取出滑动的局部区域块,也即是实现所谓局部连接的滑动窗口操作。该类的构造器的参数有:
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
- 1
不难发现其参数和卷积nn.Conv2d
有着很多相似的地方,首先都需要指定卷积核的尺寸,腐蚀大小,填充大小以及步进大小等,实际上,完全可以通过扩展Unfold
实现Conv
操作。(conv = unfold + matmul + fold
)
我们来看下unfold
的输入和输出,其输入形状如:
(
N
,
C
,
∗
)
(
N
,
C
,
∗
)
(
N
,
C
,
∗
)
(N,C,∗)(N,C,∗) (N,C,*)
(N,C,∗)(N,C,∗)(N,C,∗)C×kernel_size[0]×kernel_size[1]..., 将其拉直变为了输出形状的第二维。
给个代码实例:
inp = torch.randn(1, 3, 10, 12)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
# 这里的unfold是函数,而Unfold是层,使用差不多
print(inp_unf.shape)
输出为:torch.Size([1, 60, 56])
- 1
- 2
- 3
- 4
- 5
- 6
注意到到目前为止只是对原输入进行了一种特殊的reshape
操作而已, 并没有引入任何可学习的参数。
接下来如果是为了实现卷积操作,则还需要引入参数,主要是对每个block
的拉直结果进行加权,如果同一个图片的每个block
的参数都是相同的,那么称为参数共享,就是标准的卷积层;如果每个block
的参数都不一样,那么就不是参数共享的,此时一般称为局部连接层(Local connected layer)。接下来以参数共享的例子作为示例:
# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
(torch.nn.functional.conv2d(inp, w) - out).abs().max()
tensor(1.9073e-06)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
我们看到最后的根据间接实现的卷积操作和自带的卷积操作的相减结果最大不超过1e-6
,可以视为是计算误差,因此是等价的。
这里我们刚才的例子中涉及到了fold
,这个操作和unfold
相反,是把一系列的滑动区块拼接成一个张量。其参数列表为:
torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)
- 1
其中需要的输入形状为: ( N , C × ∏ ( k e r n e l s i z e ) , L ) ( N , C × ∏ ( k e r n e l s i z e ) , L ) ( N , C × ∏ ( k e r n e l _ s i z e ) , L ) (N,C×∏(kernel_size),L)(N,C×∏(kernel_size),L) (N, C \times \prod(\mathrm{kernel\_size}) , L) (N,C×∏(kernelsize),L)(N,C×∏(kernelsize),L)(N,C×∏(kernel_size),L)(N,C,output_size[0],output_size[1]...)。
其中,不管是Unfold
还是fold
,其参数中的padding
和dilation
都是在进行reshape之前在原数据上操作的。
Q&A
- 这个函数在
tensorflow
中对应是tf.image.extract_image_patches
,即是从图片中取块,后续要自行进行卷积操作,也就是手动实现了滑动窗口的操作。[2]
tf.image.extract_image_patches(
images,
ksizes=None,
strides=None,
rates=None,
padding=None,
name=None,
sizes=None
)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
Reference
[1]. https://stackoverflow.com/questions/53972159/how-does-pytorchs-fold-and-unfold-work
[2]. https://www.tensorflow.org/api_docs/python/tf/image/extract_image_patches
</div>
<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-b6c3c6d139.css" rel="stylesheet">
<div class="more-toolbox">
<div class="left-toolbox">
<ul class="toolbox-list">
<li class="tool-item tool-active is-like "><a href="javascript:;"><svg class="icon" aria-hidden="true">
<use xlink:href="#csdnc-thumbsup"></use>
</svg><span class="name">点赞</span>
<span class="count">5</span>
</a></li>
<li class="tool-item tool-active is-collection "><a href="javascript:;" data-report-click="{"mod":"popu_824"}"><svg class="icon" aria-hidden="true">
<use xlink:href="#icon-csdnc-Collection-G"></use>
</svg><span class="name">收藏</span></a></li>
<li class="tool-item tool-active is-share"><a href="javascript:;"><svg class="icon" aria-hidden="true">
<use xlink:href="#icon-csdnc-fenxiang"></use>
</svg>分享</a></li>
<!--打赏开始-->
<!--打赏结束-->
<li class="tool-item tool-more">
<a>
<svg t="1575545411852" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5717" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><defs><style type="text/css"></style></defs><path d="M179.176 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5718"></path><path d="M509.684 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5719"></path><path d="M846.175 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5720"></path></svg>
</a>
<ul class="more-box">
<li class="item"><a class="article-report">文章举报</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="person-messagebox">
<div class="left-message"><a href="https://blog.csdn.net/LoseInVain">
<img src="https://profile.csdnimg.cn/4/6/8/3_loseinvain" class="avatar_pic" username="LoseInVain">
<img src="https://g.csdnimg.cn/static/user-reg-year/1x/4.png" class="user-years">
</a></div>
<div class="middle-message">
<div class="title"><span class="tit"><a href="https://blog.csdn.net/LoseInVain" data-report-click="{"mod":"popu_379"}" target="_blank">FesianXu</a></span>
</div>
<div class="text"><span>发布了126 篇原创文章</span> · <span>获赞 216</span> · <span>访问量 29万+</span></div>
</div>
<div class="right-message">
<a href="https://im.csdn.net/im/main.html?userName=LoseInVain" target="_blank" class="btn btn-sm btn-red-hollow bt-button personal-letter">私信
</a>
<a class="btn btn-sm bt-button personal-watch" data-report-click="{"mod":"popu_379"}">关注</a>
</div>
</div>
</div>