SwinIR源码解读

Swinir源码解读

解读几个关键函数。

特征图到窗口,窗口到特征图

def window_partition(x, window_size):
    """
    特征图矩阵转窗口矩阵
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    窗口矩阵转特征图矩阵
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

很直接的维度拆分,但是为什么要按此步骤进行,目前看不出其中门道。想从特征图尺寸到分窗后尺寸有很多种做法,希望有大佬这里指点下迷津。
经过分窗口后,图像尺寸从 ( 1 , C , H , W 1,C,H,W 1,C,H,W) 变成 ( 1 , H / / w , W / / w , C 1,H//w,W//w,C 1,H//w,W//w,C)

window attention

class WindowAttention(nn.Module):
    ...
    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2] 
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
        ...

该函数的关键步骤
输入尺寸为 ( 窗 口 个 数 , w 2 , C ) (窗口个数, w^2, C) (,w2,C)(分窗口化后矩阵形态)
首先通过nn.linear将 C C C 扩展到 3 ∗ C 3*C 3C, 和经过一个1*1卷积的功效是一致的。
扩展后的矩阵reshape成 (3, 窗口数,head数,w2,C//窗口数)
接着将扩展后的矩阵拆分成qkv三个矩阵。
接着进行q@k,这一步的运算维度变化是:

(窗口数,head数,w2,C//窗口数)@(窗口数,head数,C//窗口数,w2
==> (窗口数,head数,w2,w2

然后是进行attn@V:

(窗口数,head数,w2,w2)@(窗口数,head数,w2,C//窗口数)
==> (窗口数,head数,w2,C//窗口数)

理解上面两步是很重要的。之前看到一篇论文claim基于QK生成的map是channel-invariant,spatial-invariant的,本次回溯也是出于探究这一点的正确性。
结论是:这一点的对错与head数有关

当head数=1时,qk生成的map与v相乘,等价于一个完全Depthwise的卷积。卷积核尺寸与窗口大小相当。
当head数=channel数时,qk生成的map与v相乘,等价于一个全卷积,尺寸仍与窗口大小相当。
当head数介于两者之间时,该操作等价于一个分组卷积,对几个head内的channel是共享参数的,但是对几个head间的channel是非共享的,尺寸仍与窗口大小相当。

经此一思,我们也就能理解Restormer做对了什么事情?
从Spatial-variant的QK map变成了Channel-Variant的QK map。
进一步我们也就能理解NAFNet做对了什么事情?
将Restormer的Transposed Attention等价于Channel attention。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值