class GroupBlock(TimmBlock):
    """Block built on ``TimmBlock`` whose attention module is swapped out.

    The parent constructor builds the block's sub-modules; this class then
    replaces ``self.attn`` with either ``Attention`` or ``GroupAttention``
    depending on ``ws``, and overrides ``forward`` so that ``H`` and ``W``
    are passed through to the attention module.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        ws=1,
    ):
        """Build the parent block, then install the replacement attention.

        Args:
            dim, num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path,
            act_layer, norm_layer: forwarded positionally, in this order,
                to ``TimmBlock.__init__``.
            qk_scale: passed only to the replacement attention module.
            sr_ratio: passed to ``Attention``; only used when ``ws == 1``.
            ws: selects the attention type; ``ws == 1`` uses ``Attention``,
                any other value uses ``GroupAttention`` and is passed to it
                (presumably a window size -- TODO confirm against
                ``GroupAttention``).
        """
        # qk_scale is deliberately NOT forwarded to the parent constructor:
        # the earlier call that passed it was removed from this file.
        super().__init__(
            dim,
            num_heads,
            mlp_ratio,
            qkv_bias,
            drop,
            attn_drop,
            drop_path,
            act_layer,
            norm_layer,
        )

        # Drop the attention module created by the parent before building
        # the replacement.
        del self.attn

        if ws != 1:
            replacement = GroupAttention(
                dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, ws
            )
        else:
            replacement = Attention(
                dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, sr_ratio
            )
        self.attn = replacement

    def forward(self, x, H, W):
        """Apply the attention and MLP sub-layers, each with a residual add.

        ``H`` and ``W`` are handed to the attention module unchanged
        (presumably the spatial height/width of the token grid -- TODO
        confirm against the attention implementations).
        """
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
VMBNet
于 2025-04-23 10:55:01 首次发布