Plug-and-Play Modules: iFormer

Paper: https://arxiv.org/abs/2205.12956
Source code: GitHub - sail-sg/iFormer: iFormer: Inception Transformer
I modified the model so that it can process images of arbitrary size and returns the processed feature map (instead of the class scores of the original paper).
For large images, however, the GPU-memory requirement can be substantial (I could not run it myself and got an OOM error; see the rough estimate after the LowMixer class below).
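For the code below to run on its own it needs a few imports. This is a minimal set, assuming DropPath, Mlp and trunc_normal_ come from timm and rearrange from einops, as in the original repository:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from einops import rearrange
from timm.models.layers import DropPath, Mlp, trunc_normal_  # assumed: same helpers as the official repo
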
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=1)  # 1x1 conv, so any input size is accepted
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C
        return x


class FirstPatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()

        self.proj1 = nn.Conv2d(in_chans, embed_dim // 2, kernel_size=1)  # 1x1 conv, so any input size is accepted
        self.norm1 = nn.BatchNorm2d(embed_dim // 2)
        self.gelu1 = nn.GELU()
        self.proj2 = nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=1)  # 1x1 conv, so any input size is accepted
        self.norm2 = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj1(x)
        x = self.norm1(x)
        x = self.gelu1(x)
        x = self.proj2(x)
        x = self.norm2(x)
        x = x.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C, e.g. (1, 480, 720, 32)
        return x
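
A quick sanity check of the two stems (hypothetical sizes, chosen only for illustration): since both use 1x1 convolutions with stride 1, the spatial resolution is preserved and only the channel count changes, with the output laid out channels-last.

stem = FirstPatchEmbed(in_chans=3, embed_dim=32)  # hypothetical sizes for illustration
x = torch.randn(1, 3, 64, 96)
print(stem(x).shape)  # torch.Size([1, 64, 96, 32]): spatial size unchanged, channels-last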


class HighMixer(nn.Module):
    def __init__(self, dim, **kwargs):
        super().__init__()

        self.cnn_in = cnn_in = dim // 2
        self.pool_in = pool_in = dim // 2

        self.cnn_dim = cnn_dim = cnn_in * 2
        self.pool_dim = pool_dim = pool_in * 2

        # conv branch: 1x1 channel expansion followed by a 3x3 depthwise conv
        self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=3, stride=1, padding=1, bias=False,
                               groups=cnn_dim)
        self.mid_gelu1 = nn.GELU()

        # pooling branch: 3x3 max pooling followed by a 1x1 channel expansion
        self.Maxpool = nn.MaxPool2d(3, stride=1, padding=1)
        self.proj2 = nn.Conv2d(pool_in, pool_dim, kernel_size=1, stride=1, padding=0)
        self.mid_gelu2 = nn.GELU()

    def forward(self, x):
        # input: B, C, H, W; the first half of the channels goes through the conv branch, the second half through the pool branch
        cx = x[:, :self.cnn_in, :, :].contiguous()
        cx = self.conv1(cx)
        cx = self.proj1(cx)
        cx = self.mid_gelu1(cx)

        px = x[:, self.cnn_in:, :, :].contiguous()
        px = self.Maxpool(px)
        px = self.proj2(px)
        px = self.mid_gelu2(px)

        hx = torch.cat((cx, px), dim=1)
        return hx
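
The high-frequency mixer keeps the spatial size but doubles the channel count, since each branch expands its half of the input from dim // 2 to dim channels. A quick check with a hypothetical width:

hm = HighMixer(dim=24)
print(hm(torch.randn(1, 24, 64, 96)).shape)  # torch.Size([1, 48, 64, 96]): each branch doubles its half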


class LowMixer(nn.Module):
    # num_heads and pool_size are hyperparameters
    def __init__(self, dim, num_heads, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(0.2)
        pool_size = 2
        self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0,
                                 count_include_pad=False) if pool_size > 1 else nn.Identity()
        self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()

    def att_fun(self, q, k, v, B, N, C):
        # q, k, v: B, num_heads, N, head_dim, e.g. torch.Size([1, 2, 86400, 4])
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(2, 3).reshape(B, C, N)  # B, C, N
        return x

    def forward(self, x):
        # input: B, C, H, W, e.g. (1, 8, 480, 720)
        B, _, _, _ = x.shape
        xa = self.pool(x)  # 2x downsampling before attention, e.g. (1, 8, 240, 360)
        _, _, ha, wa = xa.shape
        xa = xa.permute(0, 2, 3, 1).view(B, -1, self.dim)  # B, N, C with N = ha * wa
        B, N, C = xa.shape
        qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        xa = self.att_fun(q, k, v, B, N, C)
        xa = xa.view(B, C, ha, wa)
        xa = self.uppool(xa)  # back to the input resolution, e.g. (1, 8, 480, 720)
        return xa
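
Even though the attention runs on the 2x-pooled feature map, its N x N attention matrix is what makes large inputs blow up. A rough back-of-the-envelope estimate for the 480x720 case mentioned above, assuming float32 and ignoring softmax temporaries (my numbers, not from the paper):

H, W, heads = 480, 720, 2
N = (H // 2) * (W // 2)              # 86400 tokens after the 2x average pooling
attn_bytes = heads * N * N * 4       # one float32 attention matrix per head
print(f"{attn_bytes / 1e9:.1f} GB")  # ~59.7 GB just for the attention maps, hence the OOM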


class Mixer(nn.Module):
    def __init__(self, dim, **kwargs):
        super().__init__()
        num_heads = 8
        attention_head = 2
        proj_drop = 0.2
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads

        # channel split: low_dim channels go to the attention branch, high_dim channels to the conv/pool branch
        self.low_dim = low_dim = attention_head * head_dim
        self.high_dim = high_dim = dim - low_dim

        self.high_mixer = HighMixer(high_dim)
        self.low_mixer = LowMixer(low_dim, num_heads=attention_head)

        self.conv_fuse = nn.Conv2d(low_dim + high_dim * 2, low_dim + high_dim * 2, kernel_size=3, stride=1, padding=1,
                                   bias=False, groups=low_dim + high_dim * 2)
        self.proj = nn.Conv2d(low_dim + high_dim * 2, dim, kernel_size=1, stride=1, padding=0)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)  # B, H, W, C -> B, C, H, W
        hx = x[:, :self.high_dim, :, :].contiguous()
        hx = self.high_mixer(hx)  # 2 * high_dim channels, e.g. (1, 48, 480, 720)
        lx = x[:, self.high_dim:, :, :].contiguous()
        lx = self.low_mixer(lx)  # low_dim channels, e.g. (1, 8, 480, 720)
        x = torch.cat((hx, lx), dim=1)
        x = x + self.conv_fuse(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = x.permute(0, 2, 3, 1).contiguous()  # back to B, H, W, C
        return x
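
To make the channel bookkeeping concrete, here is a hypothetical width: with dim=32, head_dim is 4, so low_dim = 2 * 4 = 8 and high_dim = 24; the concatenated 2 * 24 + 8 = 56 channels are fused back to 32 by the final 1x1 projection.

mixer = Mixer(dim=32)
y = mixer(torch.randn(1, 64, 96, 32))  # channels-last in, channels-last out
print(y.shape)                         # torch.Size([1, 64, 96, 32])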


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., use_layer_scale=False, drop_path=0.,
                 layer_scale_init_value=1e-5):
        super().__init__()
        norm_layer = nn.LayerNorm
        act_layer = nn.GELU
        attn = Mixer
        self.norm1 = norm_layer(dim)

        self.attn = attn(dim, num_heads=num_heads)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

        self.use_layer_scale = use_layer_scale
        if self.use_layer_scale:
            # print('use layer scale init value {}'.format(layer_scale_init_value))
            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class InceptionTransformer(nn.Module):
    def __init__(self, H, W, in_chans=3, embed_dims=None, num_heads=None, mlp_ratio=4., weight_init='',
                 **kwargs):

        super().__init__()
        embed_layer = PatchEmbed

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        dpr = [x.item() for x in torch.linspace(0, 0.2, 40)]  # stochastic depth decay rule
        patch_size = 4
        self.patch_embed = FirstPatchEmbed(in_chans=in_chans, embed_dim=embed_dims)
        # the positional embeddings are defined at H/4 x W/4 and interpolated to the actual feature size
        self.num11_patches = H // patch_size
        self.num12_patches = W // patch_size
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.num11_patches, self.num12_patches, embed_dims))
        self.blocks1 = nn.Sequential(*[
            Block(dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_path=dpr[i])
            for i in range(0, 20)])

        self.patch_embed2 = embed_layer(in_chans=embed_dims, embed_dim=embed_dims)
        self.num21_patches = self.num11_patches // 2
        self.num22_patches = self.num12_patches // 2
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.num21_patches, self.num22_patches, embed_dims))
        self.blocks2 = nn.Sequential(*[
            Block(dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_path=dpr[i])
            for i in range(20, 40)])

        self.norm = norm_layer(embed_dims)
        # no classifier head: the model returns the feature map directly
        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)

    def _get_pos_embed(self, pos_embed, numh_patches,numw_patches, H, W):
        if H * W == numh_patches * numw_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").permute(0, 2, 3, 1)

    def forward_features(self, x):
        x = self.patch_embed(x)
        B, H, W, C = x.shape  # e.g. torch.Size([1, 480, 720, 32])
        x = x + self._get_pos_embed(self.pos_embed1, self.num11_patches, self.num12_patches, H, W)
        x = self.blocks1(x)
        x = x.permute(0, 3, 1, 2)  # B, H, W, C -> B, C, H, W
        x = self.patch_embed2(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed2, self.num21_patches, self.num22_patches, H, W)
        x = self.blocks2(x)
        x = x.flatten(1, 2)  # B, H*W, C
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        return x  # B, C, H, W

    def forward(self, x):
        x = self.forward_features(x)
        return x
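
A minimal usage sketch of the modified model, with hypothetical sizes kept small because the 40 blocks and the attention maps get expensive quickly:

model = InceptionTransformer(H=64, W=96, in_chans=3, embed_dims=32, num_heads=8).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 96))
print(out.shape)  # torch.Size([1, 32, 64, 96]): same spatial size as the input, embed_dims channels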