import torch
import torch.nn as nn

from einops import rearrange
from torch.nn.init import trunc_normal_

# ConvDW3x3, ConvGeluBN and generate_relative_distance are project-specific helpers
# used by the blocks below.

class LocalPerceptionUint(nn.Module):
    def __init__(self, dim, act=False):
        super(LocalPerceptionUint, self).__init__()
        self.act = act
        # enhance the extraction of local information
        self.conv_3x3_dw = ConvDW3x3(dim)
        if self.act:
            self.actation = nn.Sequential(
                nn.GELU(),
                nn.BatchNorm2d(dim))

    def forward(self, x):
        if self.act:
            out = self.actation(self.conv_3x3_dw(x))
            return out
        else:
            out = self.conv_3x3_dw(x)
            return out
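LocalPerceptionUint relies on the ConvDW3x3 helper, which is not shown in this snippet. A minimal sketch of what such a depthwise 3x3 convolution block could look like, assuming it keeps the spatial size and channel count (the padding and bias choices here are assumptions, not the original definition):

# Hedged sketch of a ConvDW3x3-style depthwise 3x3 convolution block.
# groups=dim gives one filter per channel, so channels and resolution are preserved.
class ConvDW3x3(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super(ConvDW3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim)

    def forward(self, x):
        return self.conv(x)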
IRFFN (Inverted Residual Feed-Forward Network)
class InvertedResidualFeedForward(nn.Module):
    def __init__(self, dim, dim_ratio=4.):
        super(InvertedResidualFeedForward, self).__init__()
        output_dim = int(dim_ratio * dim)
        self.conv1x1_gelu_bn = ConvGeluBN(
            in_channel=dim,
            out_channel=output_dim,
            kernel_size=1,
            stride_size=1,
            padding=0)
        self.conv3x3_dw = ConvDW3x3(dim=output_dim)
        self.act = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm2d(output_dim))
        self.conv1x1_pw = nn.Sequential(
            nn.Conv2d(output_dim, dim, 1, 1, 0),
            nn.BatchNorm2d(dim))

    def forward(self, x):
        x = self.conv1x1_gelu_bn(x)
        out = x + self.act(self.conv3x3_dw(x))
        out = self.conv1x1_pw(out)
        return out
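The expansion step uses ConvGeluBN, which is also defined outside this snippet. Judging purely from its call signature (in_channel, out_channel, kernel_size, stride_size, padding), a plausible sketch is a convolution followed by GELU and BatchNorm:

# Hedged sketch of a ConvGeluBN-style block, inferred from how it is called above;
# treat the layer order (conv -> GELU -> BN) as an assumption.
class ConvGeluBN(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride_size, padding=0):
        super(ConvGeluBN, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride_size, padding),
            nn.GELU(),
            nn.BatchNorm2d(out_channel))

    def forward(self, x):
        return self.block(x)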
LMHSA (Lightweight Multi-Head Self-Attention)
class LightMutilHeadSelfAttention(nn.Module):
    """Compute self-attention with a downsampled resolution for k and v,
    adding a relative position bias before the softmax.

    Args:
        dim (int) : feature map channels or dims
        num_heads (int) : number of attention heads
        relative_pos_embeeding (bool) : whether to use the relative position embedding
        no_distance_pos_embeeding (bool) : whether to use a learned (N, N) position bias without relative-distance indexing
        features_size (int) : spatial size (H = W) of the feature map
        qkv_bias (bool) : whether to add a bias to the q/kv projections
        qk_scale (float) : qk scale; if None, the default 1/sqrt(head_dim) is used
        attn_drop (float) : attention dropout rate
        proj_drop (float) : projection linear dropout rate
        sr_ratio (float) : k, v resolution downsample ratio
    Returns:
        x : LMHSA attention result with shape (B, C, H, W), the same as the input.
    """
    def __init__(self, dim, num_heads=8, features_size=56,
                 relative_pos_embeeding=False, no_distance_pos_embeeding=False, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1.):
        super(LightMutilHeadSelfAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads  # dim used inside each attention head
        self.scale = qk_scale or head_dim ** -0.5

        self.relative_pos_embeeding = relative_pos_embeeding
        self.no_distance_pos_embeeding = no_distance_pos_embeeding
        self.features_size = features_size

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        if self.relative_pos_embeeding:
            self.relative_indices = generate_relative_distance(self.features_size)
            self.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))
        elif self.no_distance_pos_embeeding:
            self.position_embeeding = nn.Parameter(torch.randn(self.features_size ** 2, self.features_size ** 2))
        else:
            self.position_embeeding = None

        if self.position_embeeding is not None:
            trunc_normal_(self.position_embeeding, std=0.2)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        x_q = rearrange(x, 'B C H W -> B (H W) C')  # translate B,C,H,W to B,(H*W),C
        q = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # B,N,H,DIM -> B,H,N,DIM

        # conv for downsampling the x resolution for k and v
        if self.sr_ratio > 1:
            x_reduce_resolution = self.sr(x)
            x_kv = rearrange(x_reduce_resolution, 'B C H W -> B (H W) C')
            x_kv = self.norm(x_kv)
        else:
            x_kv = rearrange(x, 'B C H W -> B (H W) C')

        kv_emb = rearrange(self.kv(x_kv), 'B N (dim h l) -> l B h N dim', h=self.num_heads, l=2)  # 2 B H N DIM
        k, v = kv_emb[0], kv_emb[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B H Nq DIM) @ (B H DIM Nk) -> (B H Nq Nk)

        # add the relative position bias; because k_n != q_n, the position embedding matrix has to be sliced
        q_n, k_n = q.shape[2], k.shape[2]

        if self.relative_pos_embeeding:
            attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]
        elif self.no_distance_pos_embeeding:
            attn = attn + self.position_embeeding[:, :k_n]

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B H Nq Nk) @ (B H Nk dim) -> (B Nq H*DIM)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = rearrange(x, 'B (H W) C -> B C H W', H=H, W=W)
        return x
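A quick shape check ties the three blocks together. The sizes below are purely illustrative and assume the helper sketches above (or the original helpers) are available:

# Illustrative shape check: every block keeps the (B, C, H, W) layout; LMHSA only
# downsamples its internal k/v tokens when sr_ratio > 1.
if __name__ == "__main__":
    x = torch.randn(2, 64, 56, 56)  # B, C, H, W
    lpu = LocalPerceptionUint(dim=64)
    ffn = InvertedResidualFeedForward(dim=64, dim_ratio=4.)
    lmhsa = LightMutilHeadSelfAttention(dim=64, num_heads=8, features_size=56, sr_ratio=2)
    print(lpu(x).shape, ffn(x).shape, lmhsa(x).shape)  # each torch.Size([2, 64, 56, 56])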