Paper: https://arxiv.org/abs/2205.12956
Source code: GitHub - sail-sg/iFormer (Inception Transformer)

I modified the model so that it accepts images of arbitrary size and returns the processed feature map instead of the class scores produced in the original paper. Note that large images can require a lot of GPU memory: the LowMixer runs full self-attention over a feature map pooled by 2, so the attention matrix grows quadratically with the (H/2) x (W/2) tokens. I could not run it on a large input myself (it reported an out-of-memory error).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from einops import rearrange
from timm.models.layers import DropPath, Mlp, trunc_normal_


class PatchEmbed(nn.Module):
    """2D image to patch embedding (1x1 conv, keeps the spatial size)."""

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=1)  # size-preserving projection
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C
        return x


class FirstPatchEmbed(nn.Module):
    """Stem: two 1x1 convolutions, again without downsampling."""

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj1 = nn.Conv2d(in_chans, embed_dim // 2, kernel_size=1)
        self.norm1 = nn.BatchNorm2d(embed_dim // 2)
        self.gelu1 = nn.GELU()
        self.proj2 = nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=1)
        self.norm2 = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj1(x)
        x = self.norm1(x)
        x = self.gelu1(x)
        x = self.proj2(x)
        x = self.norm2(x)
        x = x.permute(0, 2, 3, 1)  # e.g. (1, 480, 720, 32)
        return x


class HighMixer(nn.Module):
    """High-frequency branch: a depthwise-conv path and a max-pool path."""

    def __init__(self, dim, **kwargs):
        super().__init__()
        self.cnn_in = cnn_in = dim // 2
        self.pool_in = pool_in = dim // 2
        self.cnn_dim = cnn_dim = cnn_in * 2
        self.pool_dim = pool_dim = pool_in * 2

        self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=3, stride=1, padding=1,
                               bias=False, groups=cnn_dim)
        self.mid_gelu1 = nn.GELU()

        self.Maxpool = nn.MaxPool2d(3, stride=1, padding=1)
        self.proj2 = nn.Conv2d(pool_in, pool_dim, kernel_size=1, stride=1, padding=0)
        self.mid_gelu2 = nn.GELU()

    def forward(self, x):
        # x: B, C, H, W
        cx = x[:, :self.cnn_in, :, :].contiguous()
        cx = self.conv1(cx)
        cx = self.proj1(cx)
        cx = self.mid_gelu1(cx)

        px = x[:, self.cnn_in:, :, :].contiguous()
        px = self.Maxpool(px)
        px = self.proj2(px)
        px = self.mid_gelu2(px)

        hx = torch.cat((cx, px), dim=1)
        return hx


class LowMixer(nn.Module):
    """Low-frequency branch: self-attention over an average-pooled feature map."""

    def __init__(self, dim, num_heads, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(0.2)
        pool_size = 2
        self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0,
                                 count_include_pad=False) if pool_size > 1 else nn.Identity()
        self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()

    def att_fun(self, q, k, v, B, N, C):
        # q, k, v: B, num_heads, N, head_dim, e.g. (1, 2, 86400, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(2, 3).reshape(B, C, N)  # e.g. (1, 8, 86400)
        return x

    def forward(self, x):
        # x: B, C, H, W, e.g. (1, 8, 480, 720)
        B, _, _, _ = x.shape
        xa = self.pool(x)                                  # e.g. (1, 8, 240, 360)
        _, _, ha, wa = xa.shape
        xa = xa.permute(0, 2, 3, 1).view(B, -1, self.dim)  # e.g. (1, 86400, 8)
        B, N, C = xa.shape
        qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads,
                                   C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        xa = self.att_fun(q, k, v, B, N, C)
        xa = xa.view(B, C, ha, wa)
        xa = self.uppool(xa)                               # back to e.g. (1, 8, 480, 720)
        return xa


class Mixer(nn.Module):
    """Splits channels between the HighMixer (conv/pool) and the LowMixer (attention)."""

    def __init__(self, dim, **kwargs):
        super().__init__()
        num_heads = 8
        attention_head = 2
        proj_drop = 0.2
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.low_dim = low_dim = attention_head * head_dim
        self.high_dim = high_dim = dim - low_dim

        self.high_mixer = HighMixer(high_dim)
        self.low_mixer = LowMixer(low_dim, num_heads=attention_head)

        self.conv_fuse = nn.Conv2d(low_dim + high_dim * 2, low_dim + high_dim * 2, kernel_size=3,
                                   stride=1, padding=1, bias=False, groups=low_dim + high_dim * 2)
        self.proj = nn.Conv2d(low_dim + high_dim * 2, dim, kernel_size=1, stride=1, padding=0)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)

        hx = x[:, :self.high_dim, :, :].contiguous()
        hx = self.high_mixer(hx)            # e.g. (1, 48, 480, 720)

        lx = x[:, self.high_dim:, :, :].contiguous()
        lx = self.low_mixer(lx)             # e.g. (1, 8, 480, 720)

        x = torch.cat((hx, lx), dim=1)
        x = x + self.conv_fuse(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., use_layer_scale=False,
                 drop_path=0., layer_scale_init_value=1e-5):
        super().__init__()
        norm_layer = nn.LayerNorm
        act_layer = nn.GELU
        attn = Mixer

        self.norm1 = norm_layer(dim)
        self.attn = attn(dim, num_heads=num_heads)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

        self.use_layer_scale = use_layer_scale
        if self.use_layer_scale:
            # print('use layer scale init value {}'.format(layer_scale_init_value))
            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                              requires_grad=True)
            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                              requires_grad=True)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class InceptionTransformer(nn.Module):
    def __init__(self, H, W, in_chans=3, embed_dims=None, num_heads=None, mlp_ratio=4.,
                 weight_init='', **kwargs):
        super().__init__()
        embed_layer = PatchEmbed
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        dpr = [x.item() for x in torch.linspace(0, 0.2, 40)]  # stochastic depth decay rule
        patch_size = 4

        self.patch_embed = FirstPatchEmbed(in_chans=in_chans, embed_dim=embed_dims)
        self.num11_patches = H // patch_size
        self.num12_patches = W // patch_size
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.num11_patches, self.num12_patches, embed_dims))
        self.blocks1 = nn.Sequential(*[
            Block(dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_path=dpr[i])
            # use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
            for i in range(0, 20)])

        self.patch_embed2 = embed_layer(in_chans=embed_dims, embed_dim=embed_dims)
        self.num21_patches = self.num11_patches // 2
        self.num22_patches = self.num12_patches // 2
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.num21_patches, self.num22_patches, embed_dims))
        self.blocks2 = nn.Sequential(*[
            Block(dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_path=dpr[i])
            # use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
            for i in range(20, 40)])

        self.norm = norm_layer(embed_dims)
        # No classifier head or post blocks (e.g. class attention layers):
        # the network returns a feature map instead of logits.

        self.init_weights(weight_init)
    def init_weights(self, mode=''):
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)

    def _get_pos_embed(self, pos_embed, numh_patches, numw_patches, H, W):
        # Bilinearly resize the positional embedding when the input resolution
        # differs from the resolution the embedding was created for.
        if H * W == numh_patches * numw_patches:
            return pos_embed
        return F.interpolate(pos_embed.permute(0, 3, 1, 2), size=(H, W),
                             mode="bilinear").permute(0, 2, 3, 1)

    def forward_features(self, x):
        x = self.patch_embed(x)
        B, H, W, C = x.shape                 # e.g. (1, 480, 720, 32)
        x = x + self._get_pos_embed(self.pos_embed1, self.num11_patches, self.num12_patches, H, W)
        x = self.blocks1(x)

        x = x.permute(0, 3, 1, 2)
        x = self.patch_embed2(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed2, self.num21_patches, self.num22_patches, H, W)
        x = self.blocks2(x)                  # e.g. (1, 480, 720, 32)

        x = x.flatten(1, 2)
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', b=B, h=H, w=W)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x
```
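For reference, here is a minimal usage sketch. The hyperparameters (embed_dims=32, num_heads=8) and the 64x64 input are my own assumptions, chosen to match the channel counts noted in the comments above and to keep the LowMixer attention small enough to avoid the OOM mentioned earlier; they are not values from the original post.

```python
if __name__ == "__main__":
    # Assumed settings: embed_dims=32 gives high_dim=24 / low_dim=8 inside each Mixer.
    model = InceptionTransformer(H=64, W=64, in_chans=3, embed_dims=32, num_heads=8)
    model.eval()

    x = torch.randn(1, 3, 64, 64)  # H and W should be even, since LowMixer pools and upsamples by 2
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([1, 32, 64, 64]): a feature map, not class logits
```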