CARAFE: Content-Aware ReAssembly of FEatures
CARAFE: a lightweight, general-purpose upsampling operator
Shortcomings of other upsampling methods
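- Nearest-neighbor or bilinear upsampling
The upsampling kernel is determined solely by the spatial positions of pixels and makes no use of the semantic information in the feature map, so it can be viewed as a "uniform" upsampling. The receptive field is also usually very small (1x1 for nearest neighbor, 2x2 for bilinear).
- Deconvolution
The upsampling kernel is learned by the network rather than computed from pixel distances, but the same kernel is applied at every position of the feature map, so it cannot capture the content of the feature map. It also introduces a large number of parameters and a lot of computation, especially when the kernel is large.
- Dynamic filter
A different set of upsampling kernels is predicted for every position of the feature map, but the parameter count and computation grow even more sharply, and it is widely regarded as hard to train.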
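To make the first point concrete, the sketch below computes 1-D linear interpolation weights for each output position (the 2-D bilinear case applies the same weights along both axes, which gives the 2x2 receptive field). It assumes the common half-pixel coordinate mapping `src = (dst + 0.5) / scale - 0.5` with border clamping; the names `linear_weights` and `upsample_1d` are hypothetical and not from the original post. The weights depend only on the output index, never on the feature values.

```python
import math


def linear_weights(dst_index, scale, src_len):
    """Return ((i0, w0), (i1, w1)): source indices and weights for one output position.

    Assumes the half-pixel mapping src = (dst + 0.5) / scale - 0.5 and
    clamps indices at the borders (an illustrative choice).
    """
    src = (dst_index + 0.5) / scale - 0.5
    i0 = math.floor(src)
    frac = src - i0
    i1 = i0 + 1
    # Clamp both neighbours into the valid index range.
    i0c = min(max(i0, 0), src_len - 1)
    i1c = min(max(i1, 0), src_len - 1)
    return (i0c, 1.0 - frac), (i1c, frac)


def upsample_1d(values, scale):
    """Linearly upsample a 1-D list using position-only weights."""
    out = []
    for d in range(len(values) * scale):
        (i0, w0), (i1, w1) = linear_weights(d, scale, len(values))
        out.append(w0 * values[i0] + w1 * values[i1])
    return out


# The weights for a given output position are the same for any input content.
print(linear_weights(1, 2, 4))
print(upsample_1d([0.0, 1.0, 0.0, 1.0], 2))
print(upsample_1d([5.0, 5.0, 9.0, 9.0], 2))
```

Both calls to `upsample_1d` reuse exactly the same weights even though the inputs differ, which is the "uniform" behaviour described above.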
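Properties of an ideal upsampling operator
- Large receptive field: the operator needs a large receptive field so that it can make better use of the surrounding information;
- Content-aware: the upsampling kernel should depend on the semantic information of the feature map, so that upsampling is driven by the input content;
- Lightweight: the operator must not introduce too many parameters or too much computation;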
CARAFE
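CARAFE consists of two main modules: the upsampling kernel prediction module and the feature reassembly module. Let the upsampling factor be $\sigma$: an input feature map of shape $C \times H \times W$ is upsampled to $C \times \sigma H \times \sigma W$.
The kernel sizes used are $k_{encoder}=3$ and $k_{up}=5$, a trade-off between performance and computation.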
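The data flow through the two modules can be traced with plain shape arithmetic. The sketch below is only an illustration: the helper name `carafe_shapes` and the compressed channel count `c_mid` are assumptions made for this example, while the shapes themselves follow the comments in the implementation in this post.

```python
def carafe_shapes(n, c, h, w, sigma=2, k_up=5, c_mid=64):
    """Trace tensor shapes through kernel prediction and reassembly.

    c_mid (channels after compression) is a hypothetical value for this sketch.
    """
    shapes = {}
    shapes["input"] = (n, c, h, w)
    # Kernel prediction: compress channels, then predict sigma^2 * k_up^2
    # weights per input location.
    shapes["compressed"] = (n, c_mid, h, w)
    shapes["encoded"] = (n, sigma ** 2 * k_up ** 2, h, w)
    # Pixel shuffle moves the sigma^2 factor from channels into space,
    # leaving one k_up x k_up kernel per output location.
    shapes["kernels"] = (n, k_up ** 2, sigma * h, sigma * w)
    # Reassembly: each output location is a weighted sum over a k_up x k_up
    # neighbourhood of the input, applied identically to every channel.
    shapes["output"] = (n, c, sigma * h, sigma * w)
    return shapes


for name, shape in carafe_shapes(1, 256, 32, 32).items():
    print(name, shape)
```

With $\sigma=2$ and $k_{up}=5$, the encoder produces 100 channels, and after the pixel shuffle every output pixel owns 25 reassembly weights.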
import torch
import torch.nn as nn
import torch.nn.functional as F


class CARAFE(nn.Module):
    def __init__(self, inC, outC, kernel_size=3, up_factor=2):
        super(CARAFE, self).__init__()
        # Note: kernel_size serves as both k_encoder and k_up in this implementation.
        self.kernel_size = kernel_size
        self.up_factor = up_factor
        # 1x1 conv that compresses the channels before kernel prediction
        self.down = nn.Conv2d(inC, inC // 4, 1)
        # predicts S^2 * Kup^2 kernel weights for every input location
        self.encoder = nn.Conv2d(inC // 4, self.up_factor ** 2 * self.kernel_size ** 2,
                                 self.kernel_size, 1, self.kernel_size // 2)
        # 1x1 conv that maps the upsampled features to outC channels
        self.out = nn.Conv2d(inC, outC, 1)

    def forward(self, in_tensor):
        N, C, H, W = in_tensor.size()

        # N,C,H,W -> N,C,delta*H,delta*W
        # kernel prediction module
        kernel_tensor = self.down(in_tensor)  # (N, Cm, H, W)
        kernel_tensor = self.encoder(kernel_tensor)  # (N, S^2 * Kup^2, H, W)
        kernel_tensor = F.pixel_shuffle(kernel_tensor, self.up_factor)  # (N, S^2 * Kup^2, H, W)->(N, Kup^2, S*H, S*W)
        kernel_tensor = F.softmax(kernel_tensor, dim=1)  # (N, Kup^2, S*H, S*W), each kernel sums to 1
        kernel_tensor = kernel_tensor.unfold(2, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W*S, S)
        kernel_tensor = kernel_tensor.unfold(3, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W, S, S)
        kernel_tensor = kernel_tensor.reshape(N, self.kernel_size ** 2, H, W, self.up_factor ** 2)  # (N, Kup^2, H, W, S^2)
        kernel_tensor = kernel_tensor.permute(0, 2, 3, 1, 4)  # (N, H, W, Kup^2, S^2)

        # content-aware reassembly module
        # tensor.unfold: dim, size, step
        in_tensor = F.pad(in_tensor, pad=(self.kernel_size // 2, self.kernel_size // 2,
                                          self.kernel_size // 2, self.kernel_size // 2),
                          mode='constant', value=0)  # (N, C, H+Kup//2+Kup//2, W+Kup//2+Kup//2)
        in_tensor = in_tensor.unfold(2, self.kernel_size, step=1)  # (N, C, H, W+Kup//2+Kup//2, Kup)
        in_tensor = in_tensor.unfold(3, self.kernel_size, step=1)  # (N, C, H, W, Kup, Kup)
        in_tensor = in_tensor.reshape(N, C, H, W, -1)  # (N, C, H, W, Kup^2)
        in_tensor = in_tensor.permute(0, 2, 3, 1, 4)  # (N, H, W, C, Kup^2)

        out_tensor = torch.matmul(in_tensor, kernel_tensor)  # (N, H, W, C, S^2)
        out_tensor = out_tensor.reshape(N, H, W, -1)  # (N, H, W, C * S^2)
        out_tensor = out_tensor.permute(0, 3, 1, 2)  # (N, C * S^2, H, W)
        out_tensor = F.pixel_shuffle(out_tensor, self.up_factor)  # (N, C, S*H, S*W)
        out_tensor = self.out(out_tensor)  # (N, outC, S*H, S*W)
        return out_tensor


if __name__ == '__main__':
    data = torch.rand(4, 20, 10, 10)
    carafe = CARAFE(20, 10)
    print(carafe(data).size())  # torch.Size([4, 10, 20, 20])
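The `matmul` in `forward` hides what happens for each output pixel. As a reading aid, here is a slow pure-Python sketch of the reassembly step for a single-channel feature map. The functions `reassemble` and `one_hot_center` and the nested-list kernel layout are hypothetical, chosen for clarity rather than taken from the implementation above. Each output location (oi, oj) takes a weighted sum over the k x k neighbourhood centred on source location (oi // scale, oj // scale), with zero padding at the borders.

```python
def reassemble(feature, kernels, scale, k):
    """Naive content-aware reassembly for one channel.

    feature: H x W nested list of floats.
    kernels: (scale*H) x (scale*W) nested list; each entry is a k x k nested
             list of weights (assumed already normalised, e.g. by softmax).
    """
    h, w = len(feature), len(feature[0])
    r = k // 2
    out = [[0.0] * (scale * w) for _ in range(scale * h)]
    for oi in range(scale * h):
        for oj in range(scale * w):
            ci, cj = oi // scale, oj // scale  # centre of the source neighbourhood
            acc = 0.0
            for n in range(-r, r + 1):
                for m in range(-r, r + 1):
                    si, sj = ci + n, cj + m
                    if 0 <= si < h and 0 <= sj < w:  # zero padding outside
                        acc += kernels[oi][oj][n + r][m + r] * feature[si][sj]
            out[oi][oj] = acc
    return out


def one_hot_center(k):
    """k x k kernel that puts all of its weight on the centre."""
    r = k // 2
    return [[1.0 if (a == r and b == r) else 0.0 for b in range(k)] for a in range(k)]


feature = [[1.0, 2.0],
           [3.0, 4.0]]
# With a centre-only kernel everywhere, reassembly reduces to
# nearest-neighbour upsampling.
kernels = [[one_hot_center(3) for _ in range(4)] for _ in range(4)]
for row in reassemble(feature, kernels, scale=2, k=3):
    print(row)
```

In CARAFE the kernels are predicted from the input instead of being fixed, so different output pixels can weight their neighbourhoods differently; the fixed centre-only kernel here just shows that nearest-neighbour upsampling is a special case of the same operation.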
</div>
</div>