Jotting down some study notes~
For the window-partition operation in Swin Transformer, my understanding is that window-based self-attention simply divides the original feature-map height and width by the window size to get the total number of windows, and then folds that window count into the Batch_size dimension.
def window_partition(x, window_size: int):
"""
Partition the feature map into non-overlapping windows of size window_size
Args:
x: (B, H, W, C)
window_size (int): window size(M)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
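A quick shape check (a minimal sketch; the batch size, resolution, channel count, and window size below are arbitrary example values): partitioning a 56x56 feature map into 7x7 windows should give 8 x 8 = 64 windows per image, folded into the batch dimension.
import torch

x = torch.randn(2, 56, 56, 96)                # (B, H, W, C): example feature map
windows = window_partition(x, window_size=7)  # split into non-overlapping 7x7 windows
print(windows.shape)                          # torch.Size([128, 7, 7, 96]) = (2 * 64 windows, 7, 7, 96)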
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Merge the individual windows back into a single feature map
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size(M)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
After the windows have been partitioned and passed through the Transformer block, they need to be rearranged back into the original feature-map layout before the next operation. The restore has to be done step by step because a direct reshape would simply lay the elements out in their current memory order, which cannot preserve the correct spatial relationship between windows, so the recovered feature map would lose its spatial structure. In addition, splitting the image into windows may involve boundary handling or padding, which a plain reshape cannot account for; view and permute are therefore used to put every window back in its correct order and position. Each window holds local features, and a direct reshape can break the association between those local features and their locations in the global feature map.
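To make this concrete, here is a minimal sketch (example shapes chosen arbitrarily) comparing the permute-based window_reverse defined above with a naive reshape of the windows tensor: the former recovers the original tensor exactly, while the latter produces a tensor of the right shape whose spatial layout is scrambled across windows.
import torch

B, H, W, C, M = 2, 56, 56, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, M)              # (B*num_windows, M, M, C)

# Inverting with view + permute restores the feature map exactly.
assert torch.equal(window_reverse(windows, M, H, W), x)

# A naive reshape yields the right shape, but rows from different windows
# are interleaved, so the result no longer matches the original tensor.
naive = windows.reshape(B, H, W, C)
assert not torch.equal(naive, x)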