CoCoOp - CLIP 自适应Prompt工程 【二】

本文介绍了CLIP在PromptEngineering中的CoOp和CoCoOp技术,提出CoCoOp通过引入Meta-Net和条件提示词学习,有效解决了CoOp导致的泛化性折损问题,实验证明CoCoOp在保持Prompt优化力的同时提高了泛化力。
摘要由CSDN通过智能技术生成

承接上期,主要还是讲解了 CLIP 在 Prompt Engineering 方面的工作的阐述:

下面会有如下几篇博文进行  Prompt Engineering 方面的工作的阐述:

         【1】 CoOp - CLIP 自适应 Prompt 工程

         【2】 CoCoOp - CLIP 自适应 Prompt 工程 [本文]

一. 综述

由上文我们可以得知,由于 Prompt 工程 需要耗费大量的人力在 相关业务数据的适配上。于是 Kaiyang Zhou 将传统的机器学习的经验引入到了 Prompt 工程中。根据梯度迭代自适应地学习到 Prompt 词汇。但由于泛化性的原因,Prompt 后模型泛化性明显受到折损。所以说 Kaiyang Zhou 根据这个情况,提出了 条件选择性 的模型解决方案,来折衷解决这个问题。

二. 原因思考

书接上文,作者对于经过 CoOp 学习过的数据集, 如果说是 CoOp Few-Shot 过的 标签,会带来很好的效果。但是对于 没有 Few-Shot 过的标签, 如图 (b) 所示, CoOp的效果会相比 Zero-Shot 要变差 [说明比Base能力要弱],说明 CoOp会导致 CLIP 泛化性折损。

对结果进行分析上,我们回归到 FewShot 的原理上进行阐述: 

Few-Shot 的过程,实际上是对见过标签的 [过拟合] 优化。模型的权重在其中的变化会 "破坏" 掉原本初始情况下最优的泛化性能力。因本身 CLIP 模型是在大量的数据中训练到的最好的权重。Few-Shot的样本 一定在分布上会有别于 CLIP本身 4亿+左右的样本分布[来源]。 导致 CLIP 权重被破坏掉,泛化性变差。

然后理解到这个方式在上升到 CoOp 的设计:  

我们由CoOp的模型设计可以得知,由于 CoOp 的静态设计: "Learnable Context" , CoOp 在一开始学习的时候就是 "固定" 且 "只" 针对用于训练的数据集标签任取最多16个样本进行独立优化。最终训练的可学习向量也是训练后就固定不变的。相关的数据量 & 方法的操作和 Few-Shot 一模一样,所以这也就是 CoOp 为何泛化性会折损的原因。

因此为了解决如上的问题,Kaiyang Zhou  提出来了 “有条件的提示词学习” [Conditional Prompt Learning] 的思路,去解决这个问题,于是便产生了新的工作 CoCoOp。

三. CoCoOp 介绍

整个 CoCoOp 的结构如下所示:

作者为了引入  “有条件的提示词学习” [Conditional Prompt Learning] 的思路,主要做出如下改动:

【1】在模型的 图像Encoder 中引用了一个相对轻量级的网络: Meta-Net。将图像Encoder 的输出转译为一个特殊的 token π [meta token],以此来表征每一个图片的条件输出值。

【2】 将 Meta-Net 得到的 meta token 直接作为偏置项输出给 可学习prompt向量 [context tokens](也就是CoOp工作本身),  vm(x) = vm + π ,以此将 Conditional 产生的偏置量作用于 Prompt词。

四. 效果阐述

论文对于 CoCoOp 的效果评测与CoOp的评测类似,在训练时候限定了 可学习prompt向量 长度为 4。基类采用 16个样本的 Few-Shot 来进行递归迭代。 基于 11 个公开数据集,且用了三个评估角度去评估,结果非常Solid:
  【1】 基于同一个数据集, CoCoOp 与 Base Model 在 训练用的基类 & 新类下的效果, 如下图所示:

图中可以看出,对于同一个数据集, CoOp 的优化会因为泛化能力的灾难性折损下 变得效果欠佳。然而 CoCoOp 尽量保有 CoOp 在基类 标签的效果【保持Promp优化力】的同时,能在 新类中尽量跟上 CLIP 本身的能力【保持泛化力】。 在 谐波平均值H [Harmonic Mean] 中基本上做到了几个方法最高。

【2】在不同的数据集中,利用一个基数据集 下的训练去验证一个全新的测试数据集下的效果,
此处中论文利用了 ImageNet [每个Class 16个数据 Few-Shot] 做基数据集, 其他的所有数据集为 验证数据集。

与 CoOp 做对比, 虽然 CoCoOp 相对 CoOp 在 基数据集中效果有所折损, 但是在剩余的几乎所有的数据集中CoCoOp 效果都比 CoOp 好,证明其泛化能力的优秀。 

【3】在数据来源一致, 但是领域不同的数据集中验证效果,证明且领域迁移能力: 

此处中同样引用了 ImageNet 作为Base数据集,然后选用与 ImageNet 来源一样,但是图片数据领域完全不同的 ImageNet-Sketch&A&R 等数据集来进行验证,效果如下:

与 CoOp 和 CLIP, 在领域迁移数据集 [数据集类似,但是隶属不同的数据集], CoCoOp 还是在基本所有的数据集中 都好于 CoOp与 CLIP

五. 消融性分析

[1] 与 CoOp 类似,首先论文作者判断 随机初始化 & 预训练模型初始化 对于结果的影响. 如下图(a) 所示,作者利用 embedding 的初始化 和 利用 零均值高斯随机初始化做对比。  可以发现与 CoOp 不同的是, 适当的 初始化 会让模型的表现更好。

[2] 对于 可训练向量长度的分析,分别以 4,8,16 三个大小来对效果进行分析,如下图(b) 所示, 最终发现长度为8 的 可训练prompt 向量 是对整体效果最好的长度设计。

[3] 同时作者验证了 CoCoOp 的效果更好是否原因为 更大的参数量, 以此说明 Meta-Net 本身并没有其他效果,因此 作者将 CoOp 的 可训练长度增大到60,以此保证 当前 CoOp的参数量与 CoCoOp 大体一致。 结果图中可以展示,其实参数量并非真实原因。

六. 代码实现

作者利用 OpenAI 开源的CLIP 作为依赖库, 引用了部分自己 定义的 训练&测试类 等来进行论文复现: 论文代码地址
我们需将其转换为 Chinese_CLIP 所支持的造型后, 进行使用。 
Chinese_CLIP 与 CLIP 在一些算法的实现上有所不同 [包括经过 文本塔Bert 内部处理 & 计算相似度等] Chinese_Clip代码地址

下面为基于 Chinese_CLIP 的部分重要代码改动:

# ~/chinese_clip/cn_clip/clip/model.py 
.....
class CLIP(nn.Module):
    def __init__(self,...):
        
        ....
        
        ## CoCoOp 相关初始化
        # [1] Prompt Context init
        ctx_init = "这是一张"
        prompt = tokenize([ctx_init], context_length=len(ctx_init) + 2)
        embedding = self.bert.embeddings(prompt).type(self.dtype)
        ctx_vectors = embedding[0,  1 : 1 + n_ctx, :]
        
        ....
        
        # Bert 标识位 [SOS & EOS] 不计入训练
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS        
        ctx_vectors = embedding[0,  1 : 1 + n_ctx, :]
        self.ctx = nn.Parameter(ctx_vectors)
        # 偏置后 Prompt 词向量
        self.parsed_ctx = torch.empty(len(ctx_init), self.bert.text_hidden_size, dtype=dtype)
        
        # [2] Meta Net init [瓶颈层]
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(self.visual.output_dim, self.visual.output_dim// 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(self.visual.output_dim// 16, ctx_dim)) 
        ]))               
        
        ...
        
        # [3] MetaNet 前向传播 & 偏置修正
        def meta_net_forward(self, im_features):
            ctx = self.ctx                     # (n_ctx, ctx_dim)
            bias = self.meta_net(im_features)  # (batch, ctx_dim)
            bias = bias.unsqueeze(1)           # (batch, 1, ctx_dim)
            ctx = ctx.unsqueeze(0)             # (1, n_ctx, ctx_dim)  
            
            ctx_shifted = ctx + bias           # (batch, n_ctx, ctx_dim)
            
            return ctx_shifted

        ...
                
        # CoCoOp 前向传播
        def forward(self, image, text, mask_ratio=0):
            # Prompt 词生成
            image_features = self.encode_image(image, mask_ratio)
            self.parsed_ctx = self.meta_net_forward(image_features)
            # 文本塔输入
            text_features = self.encode_text(text, prompt_parsed = self.parsed_ctx)
            # 相似度计算
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return image_features, text_features, self.logit_scale.exp()                                              
# ~/chinese_clip/cn_clip/clip/modeling_bert.py 
.....
class BertModel(BertPreTrainedModel):
    ....
     def forward(self, ....):
        ....
        # CoCoOp Bert修正 
        if input_prompt_embedding != None:
            first_token_id = input_ids[0, 0].item() 
            if first_token_id == _tokenizer.vocab["[CLS]"]: # 删除 原始文本 开头标识符
                input_ids = input_ids[: ,1:]
            # Attention MASK 模块修正    
            tmp_input_ids = torch.cat((input_prompt_index, input_ids), dim=1)
            pad_index = _tokenizer.vocab['[PAD]']
            attention_mask = tmp_input_ids.ne(pad_index).type(data_dtype)
      
        ....
        # 非 Prompt 词 走 Bert 预训练 word2Vec
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        # Concat Prompt词 生成最终 文本序列        
        prompt_embedding = torch.cat((input_sos_embedding, input_prompt_embedding), dim = 1)
        embedding_output = torch.cat((prompt_embedding, embedding_output), dim = 1)   
        
        ....
        
        # 进入 Bert 文本塔
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)                                

  • 31
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值