【VLM】Long-CLIP: Unlocking the Long-Text Capability of CLIP

Related code: https://github.com/beichenzbc/Long-CLIP/blob/3966af9ae9331666309a22128468b734db4672a7/model/model_longclip.py#L243

  • Limitations of vanilla CLIP
    a) The text encoder's maximum context length is 77 tokens, which limits how much of a long caption can be expressed (see the tokenizer sketch below)
    b) As a result, fine-grained details in the caption are lost
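
    For context on point a), CLIP hard-caps every caption at 77 tokens, so longer text has to be truncated (a minimal sketch, assuming a recent version of the openai `clip` package):

    import clip

    long_caption = "a long, detailed caption describing many objects and attributes " * 40
    # default context_length=77; without truncate=True this raises a RuntimeError
    tokens = clip.tokenize(long_caption, truncate=True)   # shape [1, 77], tail is cut off
    print(tokens.shape)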

  • Long-CLIP proposes an efficient, plug-and-play fine-tuning scheme (long-text support + better modeling of the relationships between attributes)
    a) KPS (Knowledge-Preserved Stretching): stretch the positional embedding into a new table posisitonal_embedding_new; the first 20 positions are kept unchanged (frozen, no gradient updates), while the remaining 57 original positions are linearly interpolated 4x to fill the other (248-20)=228 slots, with a trainable copy (positional_embedding_res) handling the stretched positions
    Why 20: once the text length exceeds roughly 20 tokens, CLIP's accuracy grows only slowly, i.e. CLIP cannot make effective use of the extra information in longer text, so these well-trained first 20 positions are kept frozen

    # KPS interpolation code (assumes `import torch`, `import torch.nn as nn`,
    # and a loaded CLIP `model` whose positional_embedding has 77 rows)
    positional_embedding_pre = model.positional_embedding.type(model.dtype)
            
    length, dim = positional_embedding_pre.shape
    keep_len = 20
    posisitonal_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype)
    for i in range(keep_len):
        posisitonal_embedding_new[i] = positional_embedding_pre[i]
    for i in range(length-1-keep_len):
        posisitonal_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len]
        posisitonal_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4
        posisitonal_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4
        posisitonal_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4

    posisitonal_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    posisitonal_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    posisitonal_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    posisitonal_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
            
    positional_embedding_res = posisitonal_embedding_new.clone()
            
    model.positional_embedding = nn.Parameter(posisitonal_embedding_new, requires_grad=False)
    model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)
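
    A quick sanity check on the stretched length (a minimal sketch; the numbers follow the code above):

    # CLIP's 77 positions -> keep the first 20 frozen, stretch the remaining 57 by 4x
    length, keep_len = 77, 20
    new_length = 4 * length - 3 * keep_len
    print(new_length)   # 248: Long-CLIP's extended text context length
    # rows [0, 20)  are copied verbatim and frozen (requires_grad=False)
    # rows [20, 248) come from interpolation/extrapolation and are trained via positional_embedding_res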

b) PCM (Primary Component Matching): align fine-grained image features with long captions, and align coarse-grained image information (obtained by applying PCA to the fine-grained image features) with short captions; both alignments are implemented in encode_text/forward below.
Fine-tuning on long text alone would degrade performance on short text, which is why the short-text branch is kept.

As the paper puts it: "By doing so, we require the model not only to capture detailed attributes but also to discern and prioritize the importance of different attributes."

def encode_text(self, text): 
     x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
     
     # positions [0, 20) take the frozen positional_embedding (selected by mask1),
     # positions [20, 248) take the trainable positional_embedding_res (selected by mask2)
     x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype) \
           + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype)
     
     x = x.permute(1, 0, 2)  # NLD -> LND
     x = self.transformer(x)
     x = x.permute(1, 0, 2)  # LND -> NLD
     x = self.ln_final(x).type(self.dtype)

     # x.shape = [batch_size, n_ctx, transformer.width]
     # take features from the eot embedding (eot_token is the highest number in each sequence)
     x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

     return x
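
A note on the pooling step: in CLIP's BPE vocabulary the end-of-text token has the largest id, so text.argmax(dim=-1) lands exactly on the EOT position of each padded sequence. A minimal sketch (the middle token ids are arbitrary placeholders):

    import torch

    # [SOT, token, token, EOT, pad, pad]; 49406/49407 are CLIP's SOT/EOT ids
    text = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])
    print(text.argmax(dim=-1))   # tensor([3]) -> index of the EOT slot whose feature gets pooled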
     
# requires torch, torch.distributed.nn (for all_gather with gradients), and torch.nn.functional as F
def forward(self, image, text_long, text_short, rank):
        image_features_long = self.encode_image(image)
        text_features_long = self.encode_text(text_long)
        text_features_short = self.encode_text(text_short)

        # normalized features
        image_features_long = image_features_long / image_features_long.norm(dim=1, keepdim=True)
        text_features_long = text_features_long / text_features_long.norm(dim=1, keepdim=True)
        text_features_short = text_features_short / text_features_short.norm(dim=1, keepdim=True)
        image_features_short = self.PCA(image_features_long, 32)
            
        image_feat_all_long = torch.cat(torch.distributed.nn.all_gather(image_features_long), dim=0)#gather with grad
        image_features_all_short = torch.cat(torch.distributed.nn.all_gather(image_features_short), dim=0)
        text_feat_all_long = torch.cat(torch.distributed.nn.all_gather(text_features_long), dim=0)
        text_feat_all_short = torch.cat(torch.distributed.nn.all_gather(text_features_short), dim=0)
        
        sim_i2tl = torch.matmul(image_features_long, text_feat_all_long.T)
        sim_tl2i = torch.matmul(image_feat_all_long, text_features_long.T)
        sim_tl2i = sim_tl2i.T

        sim_i2ts = torch.matmul(image_features_short, text_feat_all_short.T)
        sim_ts2i = torch.matmul(image_features_all_short, text_features_short.T)
        sim_ts2i = sim_ts2i.T
        
        sim_i2tl = self.logit_scale.exp() * sim_i2tl
        sim_tl2i = self.logit_scale.exp() * sim_tl2i

        sim_i2ts = self.logit_scale.exp() * sim_i2ts
        sim_ts2i = self.logit_scale.exp() * sim_ts2i

        bs = image.size(0)
        targets = torch.linspace(rank * bs,rank * bs + bs - 1, bs, dtype=torch.long).to(image.device)
        
        loss_itcl = (
                F.cross_entropy(sim_i2tl, targets, label_smoothing=0.1)
                + F.cross_entropy(sim_tl2i, targets, label_smoothing=0.1)
            ) / 2
        
        loss_itcs = (
                F.cross_entropy(sim_i2ts, targets, label_smoothing=0.1)
                + F.cross_entropy(sim_ts2i, targets, label_smoothing=0.1)
            ) / 2
        return loss_itcl, loss_itcs
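
A note on the contrastive targets: text features are all-gathered across GPUs, so the gathered similarity matrix has world_size * bs columns, and the positives for the local batch sit at column offset rank * bs; the torch.linspace(...) call above is just an integer range starting there. The two returned losses are then combined (e.g. summed) in the training loop to form the fine-tuning objective. A minimal sketch:

    import torch

    bs, rank = 4, 2                                     # hypothetical local batch size and GPU rank
    targets = torch.arange(rank * bs, rank * bs + bs)   # tensor([ 8,  9, 10, 11])
    # equivalent to torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=torch.long)
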
  • Notes
    1) mask1 exposes only the first 20 positional embeddings; mask2 covers the remaining (248-20=)228 positions (see the sketch after this list)
    2) PCA is applied to the fine-grained visual features so that, besides the fine-grained alignment, the model also learns the primary concepts
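
    A sketch of how such positional masks could be constructed, consistent with how encode_text consumes them above (the exact construction in the repo may differ):

    import torch

    num_pos, keep_len = 248, 20
    mask1 = torch.zeros(num_pos, 1)   # selects the frozen first 20 positional embeddings
    mask1[:keep_len] = 1
    mask2 = torch.zeros(num_pos, 1)   # selects the remaining 228 trainable (stretched) positions
    mask2[keep_len:] = 1
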
# PCA rewritten with SVD to avoid inf (numerical issues with eig)
def PCA(self, input_tensor, PCA_dim):
    # compute the per-dimension mean
    mean = torch.mean(input_tensor, dim=0)
    # center the data
    X_centered = input_tensor - mean.unsqueeze(0)
    X_centered = X_centered.float()

    # use SVD instead of eig to compute the principal components
    U, S, Vt = torch.linalg.svd(X_centered, full_matrices=False)
    principal_components = Vt.T[:, :PCA_dim]

    # project onto the top PCA_dim components
    X_transformed = torch.mm(X_centered, principal_components)
    # reconstruct back into the original feature space
    X_reversed = torch.mm(X_transformed, principal_components.T)
    X_reversed += mean

    return X_reversed
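
Usage-wise, the reconstruction step means the coarse features keep the original embedding dimensionality; only their information content is reduced to the top PCA_dim directions. A minimal sketch (the 768-d feature size and the `model` instance are assumptions):

    import torch

    feats = torch.randn(64, 768)              # hypothetical batch of fine-grained image features
    coarse = model.PCA(feats, PCA_dim=32)     # project onto 32 components, then map back

    print(coarse.shape)                       # same shape as feats: torch.Size([64, 768])
    # the centered coarse features span at most 32 directions (the "primary components")
    print(torch.linalg.matrix_rank(coarse - coarse.mean(dim=0, keepdim=True)))   # <= 32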