LeMeViT:具有可学习元令牌的高效ViT

本文提出使用可学习的元令牌(meta token)来构造稀疏令牌,在有效学习关键信息的同时提高了推理速度。从技术上讲,元令牌首先通过交叉注意力从图像令牌中初始化。随后提出了双交叉注意力(DCA)来促进图像令牌和元令牌之间的信息交换:两者在双分支结构中交替充当查询(query)令牌和键/值(key/value)令牌,与自注意力相比显著降低了计算复杂度。通过在图像令牌较为密集的早期阶段使用DCA,得到了不同规模的分层结构LeMeViT。在分类和密集预测任务中的实验结果表明,与baseline相比,LeMeViT具有1.7倍的显著加速、更少的参数和有竞争力的性能,并在效率和性能之间实现了更好的权衡。

现有方法通常使用下采样或聚类(clustering)来减少当前块内的图像令牌数量,这依赖于强先验,或者对并行计算不友好。而本文通过学习元令牌来稀疏地表示密集的图像令牌。元令牌通过计算高效的双交叉注意力块,以端到端的方式与图像令牌交换信息,促进信息在各阶段之间流动。


LeMeViT总体结构:LeMeViT由三种不同的注意力块组成,依次为交叉注意力块、双交叉注意力块和标准注意力块。

通过代码来实现:

def scaled_dot_product_attention(q, k, v, scale=None):
    """Compute ``softmax(q @ k^T * scale) @ v`` over the last two dimensions.

    Args:
        q: query tensor; its last dimension is the per-head feature size.
        k: key tensor whose last dimension matches ``q``.
        v: value tensor whose second-to-last dimension matches ``k``.
        scale: multiplier applied to the attention logits. ``None`` selects
            the default ``q.shape[-1] ** -0.5``. An explicit ``0.0`` is
            honoured (it yields uniform attention weights).

    Returns:
        The attention-weighted combination of ``v``.
    """
    # Compare against None explicitly: ``scale or default`` would silently
    # replace a caller-supplied 0.0 with the default.
    if scale is None:
        scale = q.shape[-1] ** (-0.5)
    attn = (q @ k.transpose(-1, -2)) * scale
    attn = attn.softmax(dim=-1)
    return attn @ v


class DWConv(nn.Module):
    """Depthwise 3x3 convolution for channels-last (NHWC) feature maps."""

    def __init__(self, dim=768):
        super().__init__()
        # groups == channels makes the convolution depthwise; stride 1 with
        # padding 1 keeps the spatial size unchanged for the 3x3 kernel.
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim,
            bias=True,
        )

    def forward(self, x):
        """Apply the depthwise conv to an NHWC tensor and return NHWC."""
        channels_first = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        filtered = self.dwconv(channels_first)
        return filtered.permute(0, 2, 3, 1)  # NCHW -> NHWC

class Attention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections.

    The execution path is chosen once at construction time: xformers'
    memory-efficient kernel when ``has_xformers`` is truthy and the per-head
    width is a multiple of 32, otherwise ``F.multi_head_attention_forward``
    driven by the same projection weights.

    Args:
        dim: channel width of the tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_drop: attention dropout probability. Only the
            ``F.multi_head_attention_forward`` path receives it.
        proj_drop: dropout probability applied to the returned tokens.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim, num_heads, attn_drop=0.0, proj_drop=0.0, **kwargs):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.use_xformers = has_xformers and head_dim % 32 == 0

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Identity module the attention weights are routed through in eval
        # mode (the original marks need_weights as "for visualization").
        self.attn_viz = nn.Identity()

    def forward(self, x):
        """Self-attend over ``x`` laid out as (B, N, C) and return (B, N, C)."""
        if self.use_xformers:
            # Project, then split channels into heads: (B, N, C) -> (B, N, h, d).
            q, k, v = (
                rearrange(layer(x), "B N (h d) -> B N h d", h=self.num_heads)
                for layer in (self.q, self.k, self.v)
            )
            out = xops.memory_efficient_attention(q, k, v)
            out = self.proj(rearrange(out, "B N h d -> B N (h d)"))
        else:
            # F.multi_head_attention_forward expects sequence-first input.
            seq_first = rearrange(x, "B N C -> N B C")
            packed_bias = torch.cat([self.q.bias, self.k.bias, self.v.bias])

            out, attn = F.multi_head_attention_forward(
                query=seq_first,
                key=seq_first,
                value=seq_first,
                embed_dim_to_check=seq_first.shape[-1],
                num_heads=self.num_heads,
                q_proj_weight=self.q.weight,
                k_proj_weight=self.k.weight,
                v_proj_weight=self.v.weight,
                in_proj_weight=None,
                in_proj_bias=packed_bias,
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self.attn_drop,
                out_proj_weight=self.proj.weight,
                out_proj_bias=self.proj.bias,
                use_separate_proj_weight=True,
                training=self.training,
                need_weights=not self.training,  # weights only requested in eval
                average_attn_weights=False,
            )

            out = rearrange(out, "N B C -> B N C")

            if not self.training:
                attn = self.attn_viz(attn)

        return self.proj_drop(out)

class StandardAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection.

    ``forward`` picks the first available backend in this order: flash-attn,
    xformers, ``F.scaled_dot_product_attention``, and finally the local
    ``scaled_dot_product_attention`` helper.

    Args:
        dim: channel width of the tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        scale: accepted but not used by this class.
        bias: accepted but not used by this class.
        attn_drop: stored on the module; ``forward`` does not apply it.
        proj_drop: used to build ``self.proj_drop``; ``forward`` does not
            apply it.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dim,
        num_heads,
        scale=None,
        bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        # Backend availability flags, resolved once at construction.
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity()
        # Not read by forward; every backend uses its own default scaling.
        self.scale = dim**0.5

    def forward(self, x):
        """Self-attend over ``x`` laid out as (B, N, C) and return (B, N, C)."""
        heads = self.num_heads
        fused = self.qkv(x)

        if self.use_flash_attn:
            # flash-attn consumes q, k, v packed along one axis: (B, N, 3, h, d).
            packed = rearrange(fused, "B N (x h d) -> B N x h d", x=3, h=heads).contiguous()
            out = flash_attn_qkvpacked_func(packed)
            merge = "B N h d -> B N (h d)"
        elif self.use_xformers:
            q, k, v = rearrange(
                fused, "B N (x h d) -> x B N h d", x=3, h=heads
            ).contiguous().unbind(0)
            out = xops.memory_efficient_attention(q, k, v)  # B N h d
            merge = "B N h d -> B N (h d)"
        else:
            q, k, v = rearrange(
                fused, "B N (x h d) -> x B h N d", x=3, h=heads
            ).contiguous().unbind(0)
            if self.use_torchfunc:
                attend = F.scaled_dot_product_attention
            else:
                attend = scaled_dot_product_attention
            out = attend(q, k, v)  # B h N d
            merge = "B h N d -> B N (h d)"

        out = rearrange(out, merge).contiguous()
        return self.proj(out)


class DualCrossAttention(nn.Module):
    """Two-way cross-attention between two token sets ``x`` and ``c``.

    Each input gets its own fused qkv projection. ``x`` queries attend to the
    keys/values of ``c`` and ``c`` queries attend to the keys/values of
    ``x``; each result goes through its own output projection.

    Args:
        dim: channel width of both inputs; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        scale: base logit scale; a falsy value selects ``dim ** -0.5``.
        bias: accepted but not used by this class.
        attn_drop, proj_drop: used to build dropout modules that ``forward``
            does not apply.
        **kwargs: accepted and ignored.

    NOTE(review): the ``F.scaled_dot_product_attention`` backend is called
    without ``scale_x`` / ``scale_c``, so it uses that function's default
    scaling while the other three backends use the scales computed here.
    """

    def __init__(
        self,
        dim,
        num_heads,
        scale=None,
        bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        # Backend availability flags, resolved once at construction.
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv1 = nn.Linear(dim, 3 * dim)  # projection for x
        self.qkv2 = nn.Linear(dim, 3 * dim)  # projection for c
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity()

    def forward(self, x, c):
        """Exchange information between ``x`` (B, N, C) and ``c`` (B, M, C).

        Returns:
            ``(x_out, c_out)`` with the same layouts as the inputs.
        """
        _, N, _ = x.shape
        _, M, _ = c.shape
        # The base scale is multiplied by log_N(M) for the x branch and by
        # log_N(N) for the c branch.
        scale_x = math.log(M, N) * self.scale
        scale_c = math.log(N, N) * self.scale

        heads = self.num_heads
        from_x = self.qkv1(x)
        from_c = self.qkv2(c)

        if self.use_flash_attn:
            from_x = rearrange(from_x, "B N (x h d) -> B N x h d", x=3, h=heads).contiguous()
            from_c = rearrange(from_c, "B M (x h d) -> B M x h d", x=3, h=heads).contiguous()

            # Slot 0 is the query; slots 1: stay packed as (key, value).
            q1, kv1 = from_x[:, :, 0], from_x[:, :, 1:]
            q2, kv2 = from_c[:, :, 0], from_c[:, :, 1:]

            x_attn = flash_attn_kvpacked_func(q1, kv2, softmax_scale=scale_x)
            c_attn = flash_attn_kvpacked_func(q2, kv1, softmax_scale=scale_c)
            merge_x, merge_c = "B N h d -> B N (h d)", "B M h d -> B M (h d)"
        elif self.use_xformers:
            q1, k1, v1 = rearrange(
                from_x, "B N (x h d) -> x B N h d", x=3, h=heads
            ).contiguous().unbind(0)
            q2, k2, v2 = rearrange(
                from_c, "B M (x h d) -> x B M h d", x=3, h=heads
            ).contiguous().unbind(0)

            x_attn = xops.memory_efficient_attention(q1, k2, v2, scale=scale_x)  # B N h d
            c_attn = xops.memory_efficient_attention(q2, k1, v1, scale=scale_c)  # B M h d
            merge_x, merge_c = "B N h d -> B N (h d)", "B M h d -> B M (h d)"
        else:
            q1, k1, v1 = rearrange(
                from_x, "B N (x h d) -> x B h N d", x=3, h=heads
            ).contiguous().unbind(0)
            q2, k2, v2 = rearrange(
                from_c, "B M (x h d) -> x B h M d", x=3, h=heads
            ).contiguous().unbind(0)

            if self.use_torchfunc:
                # Called without an explicit scale (see class NOTE).
                x_attn = F.scaled_dot_product_attention(q1, k2, v2)  # B h N d
                c_attn = F.scaled_dot_product_attention(q2, k1, v1)  # B h M d
            else:
                x_attn = scaled_dot_product_attention(q1, k2, v2, scale=scale_x)  # B h N d
                c_attn = scaled_dot_product_attention(q2, k1, v1, scale=scale_c)  # B h M d
            merge_x, merge_c = "B h N d -> B N (h d)", "B h M d -> B M (h d)"

        x = self.proj_x(rearrange(x_attn, merge_x).contiguous())
        c = self.proj_c(rearrange(c_attn, merge_c).contiguous())
        return x, c
    
class DualCrossAttention_v2(nn.Module):
    """Two-way cross-attention that shares one query/key pair across branches.

    ``x`` is projected to ``(q, v1)`` and ``c`` to ``(k, v2)``. The x branch
    computes attention(q, k, v2); the c branch reuses the same tensors with
    the roles swapped, attention(k, q, v1). Each result goes through its own
    output projection.

    Args:
        dim: channel width of both inputs; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        scale: base logit scale; a falsy value selects ``dim ** -0.5``.
        bias: accepted but not used by this class.
        attn_drop, proj_drop: used to build dropout modules that ``forward``
            does not apply.
        **kwargs: accepted and ignored.

    NOTE(review): only the flash-attn and xformers branches pass
    ``scale_x`` / ``scale_c``; the torch and fallback branches rely on the
    default scaling of the function they call.
    """

    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        # Backend availability flags, resolved once at construction.
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qv1 = nn.Linear(dim, 2 * dim)  # x -> (q, v1)
        self.kv2 = nn.Linear(dim, 2 * dim)  # c -> (k, v2)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity()

    def forward(self, x, c):
        """Exchange information between ``x`` (B, N, C) and ``c`` (B, M, C).

        Returns:
            ``(x_out, c_out)`` with the same layouts as the inputs.
        """
        B, N, C = x.shape
        B, M, _ = c.shape
        # The base scale is multiplied by log_N(M) for the x branch and by
        # log_N(N) for the c branch.
        scale_x = math.log(M, N) * self.scale
        scale_c = math.log(N, N) * self.scale

        if self.use_flash_attn:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> B N x h d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> B M x h d", x=2, h=self.num_heads).contiguous()

            q, v1 = qv1[:,:,0], qv1[:,:,1]
            k, v2 = kv2[:,:,0], kv2[:,:,1]

            x = flash_attn_func(q, k, v2, softmax_scale=scale_x)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = flash_attn_func(k, q, v1, softmax_scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_xformers:
            # xformers takes (batch, tokens, heads, head_dim), the layout the
            # other attention classes in this file also pass to it. The
            # previous (B, h, N, d) layout put heads on the token axis.
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B N h d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B M h d", x=2, h=self.num_heads).contiguous()

            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = xops.memory_efficient_attention(q, k, v2, scale=scale_x)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = xops.memory_efficient_attention(k, q, v1, scale=scale_c)  # B M h d
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_torchfunc:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()

            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = F.scaled_dot_product_attention(q, k, v2)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = F.scaled_dot_product_attention(k, q, v1)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        else:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()

            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = scaled_dot_product_attention(q, k, v2)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = scaled_dot_product_attention(k, q, v1)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        return x, c

class CrossAttention(nn.Module):
    """One-way cross-attention: ``c`` supplies queries, ``x`` supplies keys/values.

    Args:
        dim: channel width of both inputs; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        scale: accepted but not used by this class.
        bias: accepted but not used by this class.
        attn_drop: stored on the module; ``forward`` does not apply it.
        proj_drop: used to build ``self.proj_drop``; ``forward`` does not
            apply it.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dim,
        num_heads,
        scale=None,
        bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        # Backend availability flags, resolved once at construction.
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity()

    def forward(self, x, c):
        """Update ``c`` (B, M, C) by attending to ``x`` (B, N, C); return (B, M, C)."""
        # The unpacking doubles as a rank check: both inputs must be 3-D.
        _, _, _ = x.shape
        _, _, _ = c.shape

        heads = self.num_heads
        query = self.q(c)
        key_value = self.kv(x)

        if self.use_flash_attn:
            query = rearrange(query, "B M (h d) -> B M h d", h=heads).contiguous()
            key_value = rearrange(
                key_value, "B N (x h d) -> B N x h d", x=2, h=heads
            ).contiguous()
            out = flash_attn_kvpacked_func(query, key_value)
            merge = "B M h d -> B M (h d)"
        elif self.use_xformers:
            query = rearrange(query, "B M (h d) -> B M h d", h=heads).contiguous()
            k, v = rearrange(
                key_value, "B N (x h d) -> x B N h d", x=2, h=heads
            ).contiguous().unbind(0)
            out = xops.memory_efficient_attention(query, k, v)
            merge = "B M h d -> B M (h d)"
        else:
            query = rearrange(query, "B M (h d) -> B h M d", h=heads).contiguous()
            k, v = rearrange(
                key_value, "B N (x h d) -> x B h N d", x=2, h=heads
            ).contiguous().unbind(0)
            if self.use_torchfunc:
                attend = F.scaled_dot_product_attention
            else:
                attend = scaled_dot_product_attention
            out = attend(query, k, v)
            merge = "B h M d -> B M (h d)"

        out = rearrange(out, merge).contiguous()
        return self.proj(out)


class LeMeBlock(nn.Module):
    """One LeMeViT block: conv position embedding, attention, then an MLP.

    The block processes an image feature map ``x`` (N, C, H, W) together with
    a token set ``c`` (N, M, C). ``attn_type`` selects the attention module
    and the matching forward path:

    * ``"D"`` / ``"D2"``: dual cross-attention, updates both ``x`` and ``c``.
    * ``"S"`` or ``None``: standard self-attention applied to ``x`` and to
      ``c`` independently with the same weights.
    * ``"C"``: cross-attention, updates only ``c``; ``x`` is returned as given.

    Args:
        dim: channel width.
        attn_drop, proj_drop: forwarded to the attention module.
        drop_path: stochastic-depth rate; 0 disables it.
        attn_type: one of the values listed above.
        layer_scale_init_value: if > 0, learnable per-channel scales
            ``gamma1`` / ``gamma2`` are created with this initial value.
        num_heads: number of attention heads.
        qk_dim: accepted but not used by this block.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        mlp_dwconv: insert a ``DWConv`` inside the MLP.
            NOTE(review): ``DWConv`` permutes a 4-D NHWC tensor, while every
            forward path here feeds the MLP 3-D token tensors; confirm
            before enabling.
        cpe_ks: kernel size of the depthwise conv position embedding;
            values <= 0 disable it. Use an odd size so the spatial
            resolution is preserved.
        pre_norm: ``True`` normalises before attention/MLP; ``False``
            normalises after each residual sum.
    """

    def __init__(self, dim,
                 attn_drop, proj_drop, drop_path=0., attn_type=None,
                 layer_scale_init_value=-1, num_heads=8, qk_dim=None, mlp_ratio=4, mlp_dwconv=False,
                 cpe_ks=3, pre_norm=True):
        super().__init__()

        # modules
        if cpe_ks > 0:
            # Padding follows the kernel size so the output keeps the input's
            # H x W (required by the residual add in the forward paths).
            self.pos_embed = nn.Conv2d(dim, dim, kernel_size=cpe_ks, padding=cpe_ks // 2, groups=dim)
        else:
            self.pos_embed = lambda x: 0
        self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing

        self.attn_type = attn_type
        if attn_type == "D":
            self.attn = DualCrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "D2":
            self.attn = DualCrossAttention_v2(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "S" or attn_type is None:
            self.attn = StandardAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "C":
            self.attn = CrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_dim = int(mlp_ratio*dim)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden_dim),
                                 DWConv(hidden_dim) if mlp_dwconv else nn.Identity(),
                                 nn.GELU(),
                                 nn.Linear(hidden_dim, dim)
                                )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm

    def _layer_scales(self):
        """Return ``(gamma1, gamma2)``, or ``(None, None)`` without layer scale."""
        if self.use_layer_scale:
            return self.gamma1, self.gamma2
        return None, None

    def _embed_tokens(self, x):
        """Add the conv position embedding and flatten (N, C, H, W) to (N, HW, C)."""
        _, _, H, W = x.shape
        x = x + self.pos_embed(x)
        return rearrange(x, "N C H W -> N (H W) C"), H, W

    def _attn_residual(self, tokens, update, gamma):
        """Add an attention output to ``tokens`` (post-norm applies norm1 after)."""
        if gamma is not None:
            update = gamma * update
        tokens = tokens + self.drop_path(update)
        return tokens if self.pre_norm else self.norm1(tokens)

    def _mlp_residual(self, tokens, gamma):
        """Apply the MLP sub-layer with its residual connection."""
        update = self.mlp(self.norm2(tokens)) if self.pre_norm else self.mlp(tokens)
        if gamma is not None:
            update = gamma * update
        tokens = tokens + self.drop_path(update)
        return tokens if self.pre_norm else self.norm2(tokens)

    def _self_attn_step(self, tokens, gamma1, gamma2):
        """Run self-attention followed by the MLP on a single token set."""
        attn_in = self.norm1(tokens) if self.pre_norm else tokens
        tokens = self._attn_residual(tokens, self.attn(attn_in), gamma1)
        return self._mlp_residual(tokens, gamma2)

    def forward_with_xc(self, x, c):
        """Dual cross-attention path: update both ``x`` and ``c``.

        Args:
            x: feature map, (N, C, H, W).
            c: tokens, (N, M, C).

        Returns:
            ``(x, c)`` with the input layouts.
        """
        x, H, W = self._embed_tokens(x)
        gamma1, gamma2 = self._layer_scales()

        # Post-norm variant: see https://kexue.fm/archives/9009
        if self.pre_norm:
            _x, _c = self.attn(self.norm1(x), self.norm1(c))
        else:
            _x, _c = self.attn(x, c)

        x = self._attn_residual(x, _x, gamma1)  # (N, HW, C)
        x = self._mlp_residual(x, gamma2)
        c = self._attn_residual(c, _c, gamma1)  # (N, M, C)
        c = self._mlp_residual(c, gamma2)

        x = rearrange(x, "N (H W) C -> N C H W", H=H, W=W)
        return x, c

    def forward_with_c(self, x, c):
        """Cross-attention path: update ``c`` from ``x``; ``x`` is returned as given.

        Args:
            x: feature map, (N, C, H, W).
            c: tokens, (N, M, C).

        Returns:
            ``(x, c)`` where ``x`` is the original input object.
        """
        # The position-embedded tokens only feed the attention; the caller
        # gets the untouched feature map back.
        tokens, _, _ = self._embed_tokens(x)
        gamma1, gamma2 = self._layer_scales()

        # Post-norm variant: see https://kexue.fm/archives/9009
        if self.pre_norm:
            update = self.attn(self.norm1(tokens), self.norm1(c))
        else:
            update = self.attn(tokens, c)

        c = self._attn_residual(c, update, gamma1)  # (N, M, C)
        c = self._mlp_residual(c, gamma2)
        return x, c

    def forward_with_x(self, x, c):
        """Self-attention path: ``x`` and ``c`` are each processed on their own.

        Args:
            x: feature map, (N, C, H, W).
            c: tokens, (N, M, C).

        Returns:
            ``(x, c)`` with the input layouts.
        """
        x, H, W = self._embed_tokens(x)
        gamma1, gamma2 = self._layer_scales()

        x = self._self_attn_step(x, gamma1, gamma2)  # (N, HW, C)
        c = self._self_attn_step(c, gamma1, gamma2)  # (N, M, C)

        x = rearrange(x, "N (H W) C -> N C H W", H=H, W=W)
        return x, c

    def forward(self, x, c):
        """Dispatch to the forward path that matches ``attn_type``.

        Raises:
            NotImplementedError: if ``attn_type`` is not a supported value.
        """
        if self.attn_type == "D" or self.attn_type == "D2":
            return self.forward_with_xc(x, c)
        # None builds a StandardAttention in __init__, so it must take the
        # same path as "S".
        if self.attn_type == "S" or self.attn_type is None:
            return self.forward_with_x(x, c)
        if self.attn_type == "C":
            return self.forward_with_c(x, c)
        raise NotImplementedError("Attention type does not exit")


class LeMeViT(nn.Module):
    """Hierarchical vision transformer built from ``LeMeBlock`` stages.

    An image is processed alongside a learnable set of ``queries_len`` meta
    tokens. Each stage first applies a downsample layer to the feature map
    and an MLP to the meta tokens, then runs its blocks. The final image
    features and meta tokens are normalised, averaged, summed and fed to a
    linear classifier.

    Args:
        depth: number of blocks per stage.
        in_chans: number of input image channels.
        num_classes: classifier outputs; <= 0 replaces the head by Identity.
        embed_dim: channel width per stage.
        head_dim: per-head width; heads per stage are ``qk_dims[i] // head_dim``.
        mlp_ratios: MLP expansion ratio per stage.
        qkv_bias, qk_scale: accepted but not used by this class.
        drop_rate: forwarded to the blocks as ``proj_drop``.
        attn_drop: forwarded to the blocks.
        drop_path_rate: maximum stochastic-depth rate; rates grow linearly
            from 0 across all blocks.
        attn_type: per-stage attention type ("C", "D", "D2" or "S").
        queries_len: number of meta tokens.
        qk_dims: per-stage dims used to derive head counts; defaults to
            ``embed_dim``.
        cpe_ks, pre_norm, mlp_dwconv, layer_scale_init_value: forwarded to
            ``LeMeBlock``.
        representation_size: if truthy, a Linear + Tanh ``pre_logits`` layer
            is requested.
            NOTE(review): this path passes the whole ``embed_dim`` list to
            ``nn.Linear`` and the head still takes ``embed_dim[-1]`` inputs,
            so it does not look usable as written; confirm intended design.
        use_checkpoint_stages: indices of stages to wrap with
            ``checkpoint_wrapper``; when non-empty, the stem and downsample
            layers are wrapped as well.
    """

    def __init__(self, 
                 depth=[2, 3, 4, 8, 3], 
                 in_chans=3, 
                 num_classes=1000, 
                 embed_dim=[64, 64, 128, 320, 512], 
                 head_dim=64, 
                 mlp_ratios=[4, 4, 4, 4, 4], 
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop=0., 
                 drop_path_rate=0.,
                 # <<<------
                 attn_type=["C","D","D","S","S"],
                 queries_len=128,
                 qk_dims=None,
                 cpe_ks=3,
                 pre_norm=True,
                 mlp_dwconv=False,
                 representation_size=None,
                 layer_scale_init_value=-1,
                 use_checkpoint_stages=[],
                 # ------>>>
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        qk_dims = qk_dims or embed_dim
        
        self.num_stages = len(attn_type)
        
        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.ModuleList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv 
        stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0]),
        )

        if use_checkpoint_stages:
            stem = checkpoint_wrapper(stem)
        self.downsample_layers.append(stem)

        for i in range(self.num_stages-1):
            # The stage after a "C" stage keeps the resolution (and therefore
            # requires embed_dim[i] == embed_dim[i+1]).
            if attn_type[i] == "C":
                downsample_layer = nn.Identity()
            else:
                downsample_layer = nn.Sequential(
                    nn.Conv2d(embed_dim[i], embed_dim[i+1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                    nn.BatchNorm2d(embed_dim[i+1])
                )
            if use_checkpoint_stages:
                downsample_layer = checkpoint_wrapper(downsample_layer)
            self.downsample_layers.append(downsample_layer)
        ##########################################################################


        #TODO: maybe remove last LN
        self.queries_len = queries_len
        self.meta_tokens = nn.Parameter(torch.randn(self.queries_len ,embed_dim[0]), requires_grad=True) 
        
        # Per-stage MLPs that map the meta tokens to the stage's width.
        self.meta_token_downsample = nn.ModuleList()
        meta_token_downsample = nn.Sequential(
            nn.Linear(embed_dim[0], embed_dim[0] * 4),
            nn.LayerNorm(embed_dim[0] * 4),
            nn.GELU(),
            nn.Linear(embed_dim[0] * 4, embed_dim[0]),
            nn.LayerNorm(embed_dim[0])
        )
        self.meta_token_downsample.append(meta_token_downsample)
        for i in range(self.num_stages-1):
            meta_token_downsample = nn.Sequential(
                nn.Linear(embed_dim[i], embed_dim[i] * 4),
                nn.LayerNorm(embed_dim[i] * 4),
                nn.GELU(),
                nn.Linear(embed_dim[i] * 4, embed_dim[i+1]),
                nn.LayerNorm(embed_dim[i+1])
            )
            self.meta_token_downsample.append(meta_token_downsample)

        
        self.stages = nn.ModuleList() # one entry per stage, each a list of residual blocks
        nheads= [dim // head_dim for dim in qk_dims]
        # Stochastic-depth rate increases linearly over all blocks.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] 
        cur = 0
        for i in range(self.num_stages):
            stage = nn.ModuleList(
                [LeMeBlock(dim=embed_dim[i], 
                           attn_drop=attn_drop, proj_drop=drop_rate,
                           drop_path=dp_rates[cur + j],
                           attn_type=attn_type[i],
                           layer_scale_init_value=layer_scale_init_value,
                           num_heads=nheads[i],
                           qk_dim=qk_dims[i],
                           mlp_ratio=mlp_ratios[i],
                           mlp_dwconv=mlp_dwconv,
                           cpe_ks=cpe_ks,
                           pre_norm=pre_norm
                    ) for j in range(depth[i])],
            )
            if i in use_checkpoint_stages:
                stage = checkpoint_wrapper(stage)
            self.stages.append(stage)
            cur += depth[i]

        ##########################################################################
        self.norm = nn.BatchNorm2d(embed_dim[-1])
        self.norm_c = nn.LayerNorm(embed_dim[-1])
        # Representation layer
        if representation_size:
            # NOTE(review): see the class docstring; embed_dim is a list here.
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names reported as exempt from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head.

        Args:
            num_classes: new number of outputs; <= 0 installs Identity.
            global_pool: accepted but not used.
        """
        self.num_classes = num_classes
        # embed_dim is a per-stage list; the head consumes the last stage's
        # width, exactly as in __init__.
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, c):
        """Run all stages and pool to one feature vector per sample.

        Args:
            x: input images, (B, in_chans, H, W).
            c: meta tokens, (B, queries_len, embed_dim[0]).

        Returns:
            Pooled features: spatial mean of the image branch plus token
            mean of the meta-token branch.
        """
        for i in range(self.num_stages): 
            x = self.downsample_layers[i](x)
            c = self.meta_token_downsample[i](c)
            for j, block in enumerate(self.stages[i]):
                x, c = block(x, c)
        x = self.norm(x)
        x = self.pre_logits(x)
        
        c = self.norm_c(c)
        c = self.pre_logits(c)

        # Average over spatial positions (x) and over tokens (c), then fuse
        # the two branches by addition.
        x = x.flatten(2).mean(-1)
        c = c.transpose(-2,-1).contiguous().mean(-1)
        x = x + c

        return x

    def forward(self, x):
        """Classify a batch of images shaped (B, in_chans, H, W)."""
        B, _, H, W = x.shape 
        # Give every sample its own copy of the learnable meta tokens.
        c = self.meta_tokens.repeat(B,1,1)
        x = self.forward_features(x, c)
        x = self.head(x)
        return x

同时,在整体架构的基础上,作者设计了三种不同规模的模型,即 Tiny、Small 和 Base。这些模型通过调整每个阶段的块数和特征维度来区分,其他配置在所有变体之间共享:每个注意力头的维度设置为 32,MLP 扩展率为 4,条件位置编码的卷积核大小为 3,元令牌的数量设置为 16。

@register_model
def lemevit_tiny(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    """Build the Tiny LeMeViT variant.

    Args:
        pretrained: Falsy to keep the freshly initialised weights; otherwise
            a checkpoint path (or file-like object) handed to ``torch.load``.
            The checkpoint must store the state dict under the ``"model"`` key.
        pretrained_cfg: Accepted but not used by this function.
        pretrained_cfg_overlay: Accepted but not used by this function.
        **kwargs: Forwarded to the ``LeMeViT`` constructor.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    model = LeMeViT(
        depth=[1, 2, 2, 8, 2],
        embed_dim=[64, 64, 128, 192, 320],
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        # ``check_hash`` is an option of ``torch.hub.load_state_dict_from_url``.
        # ``torch.load`` forwards unknown keywords to the unpickler, which
        # rejects it with a TypeError, so it must not be passed here.
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_small(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    """Build the Small LeMeViT variant.

    Args:
        pretrained: Falsy to keep the freshly initialised weights; otherwise
            a checkpoint path (or file-like object) handed to ``torch.load``.
            The checkpoint must store the state dict under the ``"model"`` key.
        pretrained_cfg: Accepted but not used by this function.
        pretrained_cfg_overlay: Accepted but not used by this function.
        **kwargs: Forwarded to the ``LeMeViT`` constructor.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    model = LeMeViT(
        depth=[1, 2, 2, 6, 2],
        embed_dim=[96, 96, 192, 320, 384],
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        # ``check_hash`` is an option of ``torch.hub.load_state_dict_from_url``.
        # ``torch.load`` forwards unknown keywords to the unpickler, which
        # rejects it with a TypeError, so it must not be passed here.
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_base(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    """Build the Base LeMeViT variant.

    Args:
        pretrained: Falsy to keep the freshly initialised weights; otherwise
            a checkpoint path (or file-like object) handed to ``torch.load``.
            The checkpoint must store the state dict under the ``"model"`` key.
        pretrained_cfg: Accepted but not used by this function.
        pretrained_cfg_overlay: Accepted but not used by this function.
        **kwargs: Forwarded to the ``LeMeViT`` constructor.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    model = LeMeViT(
        depth=[2, 4, 4, 18, 4],
        embed_dim=[96, 96, 192, 384, 512],
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        # ``check_hash`` is an option of ``torch.hub.load_state_dict_from_url``.
        # ``torch.load`` forwards unknown keywords to the unpickler, which
        # rejects it with a TypeError, so it must not be passed here.
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model

性能测试

下面简单测试一下模型在图像分类任务上的表现,我选择了其中的 tiny 和 small 两个版本。数据集为自制的多光照条件葡萄数据集:包含 4 个类别和 3 种光照条件,共 12 个子数据集,训练集与测试集按 8:1 的比例划分。

我的训练代码:

import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.lemevit import lemevit_small_v2
from torch.autograd import Variable
from torchvision import datasets
# Turn off cuDNN benchmark mode.
torch.backends.cudnn.benchmark = False
import warnings
# Suppress all warnings issued from this point on.
warnings.filterwarnings("ignore")
# Restrict CUDA to devices 0 and 1.
# NOTE(review): this is set after `import torch`; presumably it only takes
# effect because CUDA has not been initialised yet — confirm.
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
import pandas as pd
from torchvision.transforms import RandAugment
# Training procedure for one epoch
def train(model, device, train_loader, optimizer, epoch, model_ema):
    """Train ``model`` for one epoch.

    Reads these module-level names set up by the script: ``mixup_fn``,
    ``criterion_train``, ``use_amp``, ``scaler`` (only when ``use_amp``) and
    ``CLIP_GRAD``.

    Args:
        model: Network to optimise.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches; must expose
            ``.dataset`` and ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in log output.
        model_ema: Optional EMA wrapper; ``update(model)`` is called after
            every optimizer step when it is not ``None``.

    Returns:
        ``(average_loss, average_top1_accuracy)`` over the epoch.
    """
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # Feed the mixup *output* to the model so the inputs always match the
        # soft targets, whether or not mixup_fn modifies ``data`` in place.
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            # The forward pass must run inside autocast for mixed precision
            # to apply to the model, not just to the loss.
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so CLIP_GRAD is compared against their true norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)
        # The script falls back to the CPU when CUDA is missing, so only
        # synchronise when there is a CUDA runtime to synchronise with.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        # Accuracy is measured against the hard labels, not the mixed targets.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc


# Validation procedure
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and write checkpoints.

    Reads the module-level names ``criterion_val``, ``file_dir``, ``epoch``
    and ``use_ema``, and reads/updates the global ``Best_ACC``.

    Whenever the top-1 accuracy exceeds ``Best_ACC``, the whole model object
    is saved as ``best.pth``. A per-epoch state-dict checkpoint is saved on
    every call. A ``DataParallel`` wrapper is stripped before saving.

    Returns:
        ``(true_labels, predicted_labels, average_loss, top1_accuracy)``,
        where the two label collections are flat Python lists.
    """
    global Best_ACC
    model.eval()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []

    for data, target in test_loader:
        val_list.extend(t.data.item() for t in target)
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        _, pred = torch.max(output.data, 1)
        pred_list.extend(p.data.item() for p in pred)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_n = target.size(0)
        loss_meter.update(loss.item(), batch_n)
        acc1_meter.update(acc1.item(), batch_n)
        acc5_meter.update(acc5.item(), batch_n)
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(
        loss_meter.avg, acc, acc5_meter.avg))

    # Save the underlying module rather than the DataParallel wrapper.
    if isinstance(model, torch.nn.DataParallel):
        net = model.module
    else:
        net = model

    if acc > Best_ACC:
        torch.save(net, file_dir + '/' + 'best.pth')
        Best_ACC = acc

    state = {
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'Best_ACC': Best_ACC
    }
    if use_ema:
        state['state_dict_ema'] = net.state_dict()
    torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    return val_list, pred_list, loss_meter.avg, acc


def seed_everything(seed=0):
    """Seed the random number generators and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), sets ``PYTHONHASHSEED`` and turns on
    ``torch.backends.cudnn.deterministic``.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    import numpy as np

    # Spelling matters: ``PYTHONHASHSEED`` is the variable the interpreter
    # reads. It is consulted at interpreter start-up, so setting it here only
    # affects processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32 - 1]; wrap so any int is usable.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


# Training entry point: sets up hyper-parameters, data, model, optimizer and
# then runs the train/validate loop, writing logs, checkpoints and plots into
# ``file_dir``. The names assigned here are module globals and are read by
# ``train`` and ``val``.
if __name__ == '__main__':
    # Output directory; both branches end up creating it if needed, the first
    # one additionally prints 'true' when it already exists.
    file_dir = 'checkpoints/LEMEVIT-small/'
    if os.path.exists(file_dir):
        print('true')
        os.makedirs(file_dir,exist_ok=True)
    else:
        os.makedirs(file_dir)

    # Hyper-parameters and run switches.
    model_lr = 1e-3
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp =True
    use_dp = True
    classes = 4
    # Path of a checkpoint to resume from, or None to start from scratch.
    resume =None
    CLIP_GRAD = 5.0
    Best_ACC = 0
    use_ema=False
    model_ema_decay=0.9995
    start_epoch=1
    seed=1
    seed_everything(seed)
    # Training-time augmentation followed by resize / tensor conversion /
    # normalisation.
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])

    ])
    # Validation preprocessing: same resize and normalisation, no augmentation.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])
    ])
    # Mixup/CutMix helper used by ``train``; it also produces the soft targets
    # consumed by ``criterion_train``.
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)

    dataset_train = datasets.ImageFolder('dataset/train', transform=transform)
    dataset_test = datasets.ImageFolder("dataset/val", transform=transform_test)
    # Persist the class-name -> index mapping in two formats (written to the
    # current working directory).
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    # Soft-target loss for (mixup) training, plain cross-entropy for validation.
    criterion_train = SoftTargetCrossEntropy()
    criterion_val = torch.nn.CrossEntropyLoss()
    # Build the model and replace its head with a ``classes``-way linear layer.
    model_ft = lemevit_small_v2(pretrained=False)
    num_fr=model_ft.head.in_features
    model_ft.head =nn.Linear(num_fr,classes)
    print(model_ft)
    # Optional resume: restore weights (non-strict), best accuracy and epoch.
    if resume:
        model=torch.load(resume)
        print(model['state_dict'].keys())
        model_ft.load_state_dict(model['state_dict'],strict = False)
        Best_ACC=model['Best_ACC']
        start_epoch=model['epoch']+1
    model_ft.to(DEVICE)
    optimizer = optim.AdamW(model_ft.parameters(),lr=model_lr)
    # NOTE(review): T_max=40 is smaller than EPOCHS=50, so the schedule is
    # stepped past T_max — confirm this is intended.
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=40, eta_min=5e-8)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema=None

    # Training and validation
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    epoch_info = []
    # When resuming, continue the metric history stored in result.json.
    # (``epoch_info`` is not restored, so the xlsx/txt summaries restart.)
    if resume and os.path.isfile(file_dir+"result.json"):
        with open(file_dir+'result.json', 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        # Validate the EMA weights when EMA is enabled, the raw model otherwise.
        if use_ema:
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        # Rewrite the full metric history after every epoch.
        with open(file_dir + '/result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))
        epoch_info.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        df = pd.DataFrame(epoch_info)
        df.to_excel(file_dir + "/epoch_info.xlsx",index=False)
        # NOTE(review): unlike the other outputs this file goes to the current
        # working directory, not ``file_dir`` — confirm that is intended.
        with open('epoch_info.txt', 'w') as f:
            for epoch_data in epoch_info:
                f.write(f"Epoch: {epoch_data['epoch']}\n")
                f.write(f"Train Loss: {epoch_data['train_loss']}\n")
                f.write(f"Train Acc: {epoch_data['train_acc']}\n")
                f.write(f"Val Loss: {epoch_data['val_loss']}\n")
                f.write(f"Val Acc: {epoch_data['val_acc']}\n")
                f.write("\n")
        # With EPOCHS = 50 the condition is always true, so the fixed-LR
        # branch below is never reached in this configuration.
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                    is_set_lr = True
        # Redraw the loss curves from the full history.
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        # Show the legend
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)
        # Redraw the accuracy curves from the full history.
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'g-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'y-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)

测试代码:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import recall_score, precision_score, f1_score, accuracy_score

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom dataset definition
class CustomDataset(Dataset):
    """Image dataset laid out as one sub-directory of ``root_dir`` per class.

    Class indices follow the sorted sub-directory names. Every regular file
    inside a class directory becomes one sample; files are visited in sorted
    order so the sample order does not depend on the filesystem.

    Args:
        root_dir: Directory containing the class sub-directories.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes. A stray file in the root (notes,
        # .DS_Store, ...) would otherwise get a class index and make the
        # os.listdir call below fail.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.image_paths = []
        self.labels = []

        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip nested directories; they cannot be opened as images.
                if not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and passed through ``transform`` when
        one was given.
        """
        img_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Image preprocessing. It has to match the validation transform the model was
# trained with (resize, tensor conversion, 0.5/0.5 normalisation); without the
# Normalize step the network would see inputs in a different value range.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load the dataset
dataset_root = 'dataset/sunlight'
dataset = CustomDataset(root_dir=dataset_root, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# Load the model
model_path = 'checkpoints/LEMEVIT-small/best.pth'
# best.pth holds a whole pickled module (the training script saves it with
# torch.save(model, ...)), so only load files you trust. map_location lets a
# checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which refuses whole-module pickles; weights_only=False
# may be needed there — confirm against the installed version.
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()

# Prediction helper
def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` without gradients.

    Args:
        model: Callable mapping an input batch to per-class scores.
        dataloader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``(predictions, labels)`` as two NumPy arrays, where each prediction
        is the index of the highest score along dim 1.
    """
    predicted, expected = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_inputs)
            batch_preds = scores.max(dim=1).indices
            predicted += list(batch_preds.cpu().numpy())
            expected += list(batch_labels.cpu().numpy())

    return np.array(predicted), np.array(expected)

# Collect predictions and ground-truth labels for the whole dataset
predictions, true_labels = predict(model, dataloader, device)

# Macro-averaged recall / precision / F1, followed by plain accuracy
recall, precision, f1 = (
    metric(true_labels, predictions, average='macro')
    for metric in (recall_score, precision_score, f1_score)
)
accuracy = accuracy_score(true_labels, predictions)

# Class names: the sorted entries of the dataset root
class_names = sorted(os.listdir(dataset_root))

# One report line per value
print(
    f'Class names: {class_names}',
    f'Recall: {recall:.4f}',
    f'Precision: {precision:.4f}',
    f'F1 Score: {f1:.4f}',
    f'Accuracy: {accuracy:.4f}',
    sep='\n',
)

 结果如下:

可以看到,根据实验结果,LEMEVIT-tiny 和 LEMEVIT-small 在不同光照条件下的表现存在差异。LEMEVIT-tiny 在所有条件下的各项指标均略高于 LEMEVIT-small,尤其在阴影和正常光照下表现更为稳定。在阳光条件下,两个模型的性能均有所下降,但 LEMEVIT-tiny 依然保持了较高的精度和召回率。总体而言,LEMEVIT-tiny 对光照变化具有更好的适应性和稳定性,表现出更高的鲁棒性。当然,以上结论仅在我的数据集上成立,各位可以在自己的数据上自行实验验证。

列出模型的地址:https://github.com/ViTAE-Transformer/LeMeViT

以上为全部内容!

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值