Innovations (quoted from the original paper)
our PVT overcomes the difficulties of the conventional Transformer by (1) taking fine-grained image patches (i.e., 4×4 pixels per patch) as input to learn high-resolution representation, which is essential for dense prediction tasks; (2) introducing a progressive shrinking pyramid to reduce the sequence length of Transformer as the network deepens, significantly reducing the computational cost, and (3) adopting a spatial-reduction attention (SRA) layer to further reduce the resource consumption when learning high-resolution features.
PVT has four stages, and as the architecture diagram shows, each stage consists of a patch embedding module and a Transformer encoder.
1) How is the resolution reduced?
The figure above shows that the input resolution drops after each stage, which is the effect of the patch embedding module. So how exactly is the resolution reduced?
import torch.nn as nn
from timm.models.layers import to_2tuple  # to_2tuple comes from timm, as in the official PVT code

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        #     f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Downsampling happens here: kernel_size == stride == patch_size, so each
        # patch_size x patch_size patch is projected to a single embed_dim-d token.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, N, embed_dim), with N = HW/P^2
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]
        return x, (H, W)
The key line:

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

This is simply a strided convolution. For example, to reduce H×W to H/4×W/4, a stride of 4 is all that is needed.
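A minimal shape check, assuming the PatchEmbed class above is in scope; patch_size=4 and embed_dim=64 are the paper's stage-1 settings, used here purely for illustration:

import torch

patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=64)
x = torch.randn(1, 3, 224, 224)   # (B, C, H, W)
tokens, (H, W) = patch_embed(x)
print(tokens.shape)               # torch.Size([1, 3136, 64]): 56*56 = 3136 tokens
print(H, W)                       # 56 56, i.e. H/4 x W/4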
Explanation of point (1) of the paper: taking fine-grained image patches (i.e., 4×4 pixels per patch) as input to learn high-resolution representation, which is essential for dense prediction tasks.
"Fine-grained patches as input" means that each token covers 4×4 pixels, rather than the 16×16 or 32×32 pixels per token used in ViT. For a 224×224 input, stage 1 therefore produces 56×56 = 3136 tokens, versus 14×14 = 196 tokens in ViT with 16×16 patches.
Explanation of point (2) of the paper: introducing a progressive shrinking pyramid to reduce the sequence length of Transformer as the network deepens, significantly reducing the computational cost.
The spatial resolution drops after each stage, and by the standard self-attention complexity formula, Ω(MSA) = 4·HW·C² + 2·(HW)²·C, the cost shrinks accordingly: the dominant term is quadratic in the sequence length HW.
(For a detailed derivation, see the CSDN post "Swin-Transformer中MSA和W-MSA模块计算复杂度推导", where the author works through it very carefully.)
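As a back-of-the-envelope sketch, the snippet below plugs the per-stage resolutions of a 224×224 input and the PVT channel widths (64/128/320/512) into this formula. The numbers are rough per-layer operation counts for full (non-SRA) attention, for illustration only:

# Standard MSA cost from the formula above: 4*N*C^2 (q/k/v/output projections)
# plus 2*N^2*C (QK^T and attn @ V), with N = H*W.
def attn_ops(H, W, C):
    N = H * W
    return 4 * N * C**2 + 2 * N**2 * C

for stage, (H, W, C) in enumerate([(56, 56, 64), (28, 28, 128),
                                   (14, 14, 320), (7, 7, 512)], start=1):
    print(f"stage {stage}: N = {H*W:5d}, ~{attn_ops(H, W, C)/1e9:.2f} G ops per layer")

Running this shows the quadratic term collapsing as HW shrinks, which is exactly what the progressive shrinking pyramid buys.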
2) What is a spatial-reduction attention (SRA) layer?
First we need to be clear about what self-attention in the Transformer is, what the sequence length is, and what the number of features is. The input to self-attention has the shape (sequence length, number of features), which we write as (HW, C); the sequence length is simply the spatial size of the input to each stage. SRA reduces the computation of the attention step by shrinking that sequence length, which is point (3) of the paper: adopting a spatial-reduction attention (SRA) layer to further reduce the resource consumption when learning high-resolution features.
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Spatial reduction: a strided conv shrinks the (H, W) map by
            # sr_ratio before keys and values are computed.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape  # N = H * W
        # Queries keep the full sequence length N.
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # (B, N, C) -> (B, C, H, W) -> strided conv -> (B, N/sr_ratio^2, C)
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # Attention matrix is (N, N/sr_ratio^2) instead of (N, N).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
The key code:

if sr_ratio > 1:
    self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)

The code shows that SRA differs from ordinary attention in that the sequence length of k and v is no longer the input HW but is reduced according to the stride: a strided convolution shrinks the feature map to (H/sr_ratio)×(W/sr_ratio) before the keys and values are computed. Why does shrinking the k/v sequence length save resources? The quadratic term 2·(HW)²·C becomes 2·HW·(HW/sr_ratio²)·C; see the complexity derivation linked above for the details.
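A quick shape check, assuming the Attention class above is in scope; dim=64 and sr_ratio=8 are the paper's stage-1 values, used here only for illustration:

import torch

attn = Attention(dim=64, num_heads=1, sr_ratio=8)
H, W = 56, 56
x = torch.randn(1, H * W, 64)   # (B, N, C) with N = 3136
out = attn(x, H, W)
print(out.shape)                # torch.Size([1, 3136, 64]): output keeps full length
# Internally k and v have length (56 // 8) * (56 // 8) = 49 instead of 3136,
# so the attention matrix is 3136 x 49 rather than 3136 x 3136.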
3) A thought
Although the resolution drops after each stage, within any single stage all tokens have the same granularity. That is presumably why cross-scale attention appeared later.