Official Source Code
After reading the paper, I was curious how the architecture it describes is actually implemented. After going through the code, the following questions are answered one by one:
1: How is multi-scale embedding implemented in code?
2: After multi-scale embedding, how is the input fed into the multi-path transformer module?
3: What is the factorized MHSA in the transformer encoder, and how is it implemented?
4: Is the input to the global-to-local feature interaction module a sequence or an image?
For the first question:
Convolutional relative position encoding is used. From the paper we know that patch embedding is performed with convolution kernels of different sizes.
class ConvRelPosEnc(nn.Module):
"""Convolutional relative position encoding."""
    def __init__(self, Ch, h, window):  # window is supplied by the caller of ConvRelPosEnc
        super().__init__()
        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            # window is a dict mapping window size -> number of heads,
            # e.g. crpe_window = {3: 2, 5: 3, 7: 3}.
            self.window = window
        else:
            raise ValueError()
self.conv_list = nn.ModuleList()
self.head_splits = []
        for cur_window, cur_head_split in window.items():  # .items() yields (window, head-split) pairs: (3, 2), (5, 3), (7, 3)
dilation = 1 # Use dilation=1 at default.
padding_size = (cur_window + (cur_window - 1) *
(dilation - 1)) // 2
            # padding_size = 1, 2, 3 for window sizes 3, 5, 7 ('same' padding)
cur_conv = nn.Conv2d(
cur_head_split * Ch,
cur_head_split * Ch,
kernel_size=(cur_window, cur_window),
padding=(padding_size, padding_size),
dilation=(dilation, dilation),
groups=cur_head_split * Ch,
)
self.conv_list.append(cur_conv)
self.head_splits.append(cur_head_split)
self.channel_splits = [x * Ch for x in self.head_splits]
def forward(self, q, v, size):
"""foward function"""
B, h, N, Ch = q.shape
H, W = size
# We don't use CLS_TOKEN
q_img = q
v_img = v
        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W]. einops.rearrange reshapes by named axes.
v_img = rearrange(v_img, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W)
# Split according to channels.
        # Split along the h*Ch channel dimension.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # chunks of sizes self.channel_splits along dim 1
conv_v_img_list = [
conv(x) for conv, x in zip(self.conv_list, v_img_list)
]
conv_v_img = torch.cat(conv_v_img_list, dim=1)
# Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
conv_v_img = rearrange(conv_v_img, "B (h Ch) H W -> B h (H W) Ch", h=h)
EV_hat_img = q_img * conv_v_img
EV_hat = EV_hat_img
return EV_hat
①: The loop registers three depthwise convolutions (input/output channels 2Ch, 3Ch, 3Ch; kernel sizes 3, 5, 7) via self.conv_list.append(cur_conv) and appends each cur_head_split to self.head_splits, so that after the loop self.channel_splits = [2Ch, 3Ch, 3Ch].
②: In the forward function, V (the v of qkv) is split into three tensors of shapes [B, 2Ch, H, W], [B, 3Ch, H, W] and [B, 3Ch, H, W] (h is the number of heads, Ch is the total dimension divided by the number of heads, so h x Ch is the total dimension). Each chunk is fed through its own convolution; the three results are concatenated along dim 1, restoring the shape [B, h*Ch, H, W], which is then rearranged back to [B, h, H*W, Ch] and finally multiplied elementwise by q (shape [B, h, H*W, Ch]). The output therefore keeps the shape [B, h, H*W, Ch].
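To make the shapes concrete, here is a minimal smoke test of my own (not from the repo), assuming torch and einops.rearrange are imported as the class requires:
import torch

# 8 heads with 32 dims per head, using the paper's window dict {3: 2, 5: 3, 7: 3}
# (2 + 3 + 3 = 8 heads in total).
crpe = ConvRelPosEnc(Ch=32, h=8, window={3: 2, 5: 3, 7: 3})
print(crpe.channel_splits)  # [64, 96, 96], i.e. [2*Ch, 3*Ch, 3*Ch]

B, h, H, W, Ch = 2, 8, 14, 14, 32
q = torch.randn(B, h, H * W, Ch)
v = torch.randn(B, h, H * W, Ch)
out = crpe(q, v, size=(H, W))
print(out.shape)  # torch.Size([2, 8, 196, 32]) -- same shape as q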
For the third question:
class FactorAtt_ConvRelPosEnc(nn.Module):
"""Factorized attention with convolutional relative position encoding class."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
shared_crpe=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# Shared convolutional relative position encoding.
self.crpe = shared_crpe
def forward(self, x, size):
B, N, C = x.shape
# Generate Q, K, V.
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
.contiguous()
) # Shape: [3, B, h, N, Ch].
q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch].
# Factorized attention.
k_softmax = k.softmax(dim=2) # Softmax on dim N.
        # Einstein summation: elementwise multiply then sum over n; inputs (b,h,n,ch) and (b,h,n,ch) give output (b,h,ch,ch).
k_softmax_T_dot_v = einsum(
"b h n k, b h n v -> b h k v", k_softmax, v
) # Shape: [B, h, Ch, Ch].
        # Inputs (b,h,n,ch) and (b,h,ch,ch) give output (b,h,n,ch).
factor_att = einsum(
"b h n k, b h k v -> b h n v", q, k_softmax_T_dot_v
) # Shape: [B, h, N, Ch].
# Convolutional relative position encoding.
crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch].
# Merge and reshape.
x = self.scale * factor_att + crpe
x = (
x.transpose(1, 2).reshape(B, N, C).contiguous()
) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].
# Output projection.
x = self.proj(x)
x = self.proj_drop(x)
return x
According to the formula in the paper, FactorAtt(Q, K, V) = (Q / √C)(softmax(K)ᵀ V):
We first obtain q, k and v, then apply softmax to k along the token dimension, next compute the innermost term softmax(K)ᵀ V, and then multiply it by q.
# Einstein summation: inputs (b,h,n,ch) and (b,h,n,ch) give output (b,h,ch,ch).
k_softmax_T_dot_v = einsum(
    "b h n k, b h n v -> b h k v", k_softmax, v
)  # Shape: [B, h, Ch, Ch].
# Inputs (b,h,n,ch) and (b,h,ch,ch) give output (b,h,n,ch).
factor_att = einsum(
    "b h n k, b h k v -> b h n v", q, k_softmax_T_dot_v
)  # Shape: [B, h, N, Ch].
Next we apply the 1/√C scaling (self.scale = qk_scale or head_dim ** -0.5) and add the convolutional relative position encoding:
x = self.scale * factor_att + crpe
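The two einsum calls are just batched matrix products; the following sketch of my own verifies the equivalence and shows why this factorization is cheap:
import torch

B, h, N, Ch = 2, 8, 196, 32
q = torch.randn(B, h, N, Ch)
k_softmax = torch.randn(B, h, N, Ch).softmax(dim=2)
v = torch.randn(B, h, N, Ch)

# "b h n k, b h n v -> b h k v" is softmax(K)^T @ V per head: [B, h, Ch, Ch].
kTv = k_softmax.transpose(-2, -1) @ v
assert torch.allclose(kTv, torch.einsum("bhnk,bhnv->bhkv", k_softmax, v), atol=1e-5)

# "b h n k, b h k v -> b h n v" is Q @ (softmax(K)^T V): [B, h, N, Ch].
factor_att = q @ kTv
assert torch.allclose(factor_att, torch.einsum("bhnk,bhkv->bhnv", q, kTv), atol=1e-5)

# Computing (Q @ K^T) @ V directly would build an N x N attention map, O(N^2 * Ch);
# Q @ (K^T @ V) never materializes it and costs O(N * Ch^2), linear in token count.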
After the relative position encoding is added inside the module, the rest follows the usual transformer block flow:
class MHCABlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=3,
drop_path=0.0,
qkv_bias=True,
qk_scale=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
shared_cpe=None,
shared_crpe=None,
):
super().__init__()
self.cpe = shared_cpe
self.crpe = shared_crpe
self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
shared_crpe=shared_crpe,
)
self.mlp = Mlp(in_features=dim, hidden_features=dim * mlp_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
def forward(self, x, size):
# x.shape = [B, N, C]
if self.cpe is not None:
x = self.cpe(x, size)
cur = self.norm1(x)
x = x + self.drop_path(self.factoratt_crpe(cur, size))
cur = self.norm2(x)
x = x + self.drop_path(self.mlp(cur))
return x
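MHCABlock also references Mlp and DropPath, which this post does not reproduce; both are the usual timm-style components. For completeness, a typical Mlp looks roughly like this (a sketch, not necessarily the repo's exact code):
import torch.nn as nn

class Mlp(nn.Module):
    """Standard two-layer feed-forward block used after attention."""
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x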
Since the number of transformer blocks differs from stage to stage, they are created in a loop (for idx in range(self.num_layers)); the input x passes through the stacked MHCABlocks one after another, and the resulting sequence is then reshaped back into a feature map.
class MHCAEncoder(nn.Module):
def __init__(
self,
dim,
num_layers=1,
num_heads=8,
mlp_ratio=3,
drop_path_list=[],
qk_scale=None,
crpe_window={3: 2, 5: 3, 7: 3},
):
super().__init__()
self.num_layers = num_layers
self.cpe = ConvPosEnc(dim, k=3)
self.crpe = ConvRelPosEnc(Ch=dim // num_heads, h=num_heads, window=crpe_window)
self.MHCA_layers = nn.ModuleList(
[
MHCABlock(
dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_path=drop_path_list[idx],
qk_scale=qk_scale,
shared_cpe=self.cpe,
shared_crpe=self.crpe,
)
for idx in range(self.num_layers)
]
)
def forward(self, x, size):
H, W = size
B = x.shape[0]
# x' shape : [B, N, C]
for layer in self.MHCA_layers:
x = layer(x, (H, W))
        # This answers the fourth question: the sequence is reshaped back
        # into an image before leaving the encoder.
        # return x's shape : [B, N, C] -> [B, C, H, W]
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
return x
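MHCAEncoder also shares a ConvPosEnc across its blocks, which the post does not show. In CoaT-style codebases this is a depthwise 3x3 convolution with a residual connection; a sketch along those lines (see the repo for the exact version):
import torch.nn as nn

class ConvPosEnc(nn.Module):
    """Convolutional position encoding: a depthwise k x k convolution
    over the 2D feature map, added back to the input as a residual."""
    def __init__(self, dim, k=3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=dim)

    def forward(self, x, size):
        B, N, C = x.shape
        H, W = size
        # [B, N, C] -> [B, C, H, W], convolve, add residual, flatten back.
        feat = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(feat) + feat
        return x.flatten(2).transpose(1, 2)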
Finally, let's look at the input flow:
class MHCA_stage(nn.Module):
def __init__(
self,
embed_dim,
out_embed_dim,
num_layers=1,
num_heads=8,
mlp_ratio=3,
num_path=4,
norm_cfg=dict(type="BN"),
drop_path_list=[],
):
super().__init__()
self.mhca_blks = nn.ModuleList(
[
MHCAEncoder(
embed_dim,
num_layers,
num_heads,
mlp_ratio,
drop_path_list=drop_path_list,
)
for _ in range(num_path)
]
)
self.InvRes = ResBlock(
in_features=embed_dim, out_features=embed_dim, norm_cfg=norm_cfg
)
self.aggregate = Conv2d_BN(
embed_dim * (num_path + 1),
out_embed_dim,
act_layer=nn.Hardswish,
norm_cfg=norm_cfg,
)
def forward(self, inputs):
att_outputs = [self.InvRes(inputs[0])]
for x, encoder in zip(inputs, self.mhca_blks):
# [B, C, H, W] -> [B, N, C]
_, _, H, W = x.shape
            # Flatten the spatial dims, then swap dims 1 and 2: [B, C, H*W] -> [B, H*W, C].
x = x.flatten(2).transpose(1, 2).contiguous()
att_outputs.append(encoder(x, size=(H, W)))
out_concat = torch.cat(att_outputs, dim=1)
out = self.aggregate(out_concat)
return out
The input is inputs, a list with one feature map per path. inputs[0] first goes through self.InvRes, i.e. the convolutional local feature module; in addition, every path's input (including inputs[0]) goes through its own transformer encoder. All num_path + 1 results are collected in the att_outputs list, concatenated along the channel dimension, and fused by self.aggregate, the global-to-local feature interaction module.
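A quick shape check of my own (assuming the repo's ResBlock and Conv2d_BN helpers are in scope): the stage consumes one tensor per path and fuses num_path + 1 branches:
import torch

# Hypothetical smoke test: 4 paths at 64 channels, fused to 128 channels.
stage = MHCA_stage(embed_dim=64, out_embed_dim=128, num_layers=1,
                   num_heads=8, mlp_ratio=3, num_path=4)
inputs = [torch.randn(2, 64, 14, 14) for _ in range(4)]  # one map per path
out = stage(inputs)
# 5 branches of 64 channels each enter self.aggregate (320 -> 128 channels).
print(out.shape)  # torch.Size([2, 128, 14, 14])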
The whole model is as follows:
It mainly passes the input through self.patch_embed_stages and then self.mhca_stages, stage by stage; with that, the model construction is complete.
class MPViT(nn.Module):
"""Multi-Path ViT class."""
def __init__(
self,
num_classes=80,
in_chans=3,
num_stages=4,
num_layers=[1, 1, 1, 1],
mlp_ratios=[8, 8, 4, 4],
num_path=[4, 4, 4, 4],
embed_dims=[64, 128, 256, 512],
num_heads=[8, 8, 8, 8],
drop_path_rate=0.0,
norm_cfg=dict(type="BN"),
norm_eval=True,
pretrained=None,
):
super().__init__()
self.num_classes = num_classes
self.num_stages = num_stages
self.conv_norm_cfg = norm_cfg
self.norm_eval = norm_eval
dpr = dpr_generator(drop_path_rate, num_layers, num_stages)
self.stem = nn.Sequential(
Conv2d_BN(
in_chans,
embed_dims[0] // 2,
kernel_size=3,
stride=2,
pad=1,
act_layer=nn.Hardswish,
norm_cfg=self.conv_norm_cfg,
),
Conv2d_BN(
embed_dims[0] // 2,
embed_dims[0],
kernel_size=3,
stride=2,
pad=1,
act_layer=nn.Hardswish,
norm_cfg=self.conv_norm_cfg,
),
)
# Patch embeddings.
self.patch_embed_stages = nn.ModuleList(
[
Patch_Embed_stage(
embed_dims[idx],
num_path=num_path[idx],
isPool=False if idx == 0 else True,
norm_cfg=self.conv_norm_cfg,
)
for idx in range(self.num_stages)
]
)
# Multi-Head Convolutional Self-Attention (MHCA)
self.mhca_stages = nn.ModuleList(
[
MHCA_stage(
embed_dims[idx],
embed_dims[idx + 1]
if not (idx + 1) == self.num_stages
else embed_dims[idx],
num_layers[idx],
num_heads[idx],
mlp_ratios[idx],
num_path[idx],
norm_cfg=self.conv_norm_cfg,
drop_path_list=dpr[idx],
)
for idx in range(self.num_stages)
]
)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
if isinstance(pretrained, str):
self.apply(_init_weights)
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError("pretrained must be a str or None")
def forward_features(self, x):
# x's shape : [B, C, H, W]
outs = []
x = self.stem(x) # Shape : [B, C, H/4, W/4]
for idx in range(self.num_stages):
att_inputs = self.patch_embed_stages[idx](x)
x = self.mhca_stages[idx](att_inputs)
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
return x
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(MPViT, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
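Two loose ends, sketched under my own assumptions. First, the constructor calls dpr_generator, which is not reproduced above; it plausibly builds the usual linearly increasing stochastic-depth schedule, grouped into one list per stage (the exact helper lives in the repo):
import torch

def dpr_generator(drop_path_rate, num_layers, num_stages):
    """Linearly increasing drop-path rates, split into one list per stage."""
    dpr_list = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))]
    dpr, cur = [], 0
    for idx in range(num_stages):
        dpr.append(dpr_list[cur:cur + num_layers[idx]])
        cur += num_layers[idx]
    return dpr
Second, a hypothetical end-to-end shape check with the default arguments, assuming all the helper modules are available (the stem downsamples by 4, and stages 2-4 each halve the resolution through their pooling patch embeddings):
model = MPViT()  # defaults: embed_dims=[64, 128, 256, 512], num_stages=4
x = torch.randn(1, 3, 224, 224)
outs = model(x)
for o in outs:
    print(o.shape)
# Expected, roughly:
# [1, 128, 56, 56], [1, 256, 28, 28], [1, 512, 14, 14], [1, 512, 7, 7]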