-
clip-based的局限性
a)文本模型的最大长度为77, 限制了长文本的表达
b)导致细粒度信息丢失 -
LongClip提出高效的微调方案(支持长文本+改进模型对属性间关系的理解),即插即用
a)KPS(Knowledge-Preserved Stretching): 引入新的可学习参数posisitonal_embedding_new,保留前20个pos_emb不变(不需要进行梯度更新), 对后(248-20)=228个pos_emb进行插值
20:当文本长度超过20时, 模型精度增长缓慢, 对于过长的文本,clip无法更有效地利用额外的信息
# Interpolation code for KPS (Knowledge-Preserved Stretching).
# Stretches CLIP's 77-entry text positional-embedding table to
# 4*length - 3*keep_len entries (77 -> 248): the first `keep_len`
# positions are copied unchanged, the rest are 4x linearly interpolated.
positional_embedding_pre = model.positional_embedding.type(model.dtype)
length, dim = positional_embedding_pre.shape
# Per the notes above: beyond ~20 tokens CLIP's accuracy gains flatten out,
# so the first 20 well-trained positions are preserved verbatim.
keep_len = 20
# NOTE(review): "posisitonal" typo kept from the original so all later
# references stay consistent.
posisitonal_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype)
# Copy the preserved prefix as-is (these entries get no gradient updates).
for i in range(keep_len):
    posisitonal_embedding_new[i] = positional_embedding_pre[i]
# For each remaining pair of neighbouring source embeddings, emit the first
# one plus three linear interpolations (weights 3/4:1/4, 2/4:2/4, 1/4:3/4).
for i in range(length-1-keep_len):
    posisitonal_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len]
    posisitonal_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4
    posisitonal_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4
    posisitonal_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4
# The final four slots extrapolate past the last source embedding using the
# step between the last two entries; the first of them (coefficient 0) is
# exactly the last source embedding.
posisitonal_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
posisitonal_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
posisitonal_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
posisitonal_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
# Keep a separate copy as the trainable residual table; the stretched base
# table itself is frozen (requires_grad=False below).
positional_embedding_res = posisitonal_embedding_new.clone()
model.positional_embedding = nn.Parameter(posisitonal_embedding_new, requires_grad=False)
model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)
b)PCM(Primary Component Matching): 对齐—>细粒度图像特征+长文本; 对齐—>粗粒度信息(细粒度图像特征经PCA后获取)+短文本
若仅对长文本微调,会导致在短文本上性能下降,
By doing so, we require the model not only to capture detailed attributes but also to discern and prioritize the importance of different attributes.
def encode_text(self, text):
    """Encode tokenized text using the dual positional-embedding scheme.

    The frozen stretched table (gated by ``mask1``) and the trainable
    residual table (gated by ``mask2``) are summed onto the token
    embeddings before the transformer runs.

    Args:
        text: integer token tensor of shape [batch_size, n_ctx].

    Returns:
        Per-sequence text features projected through ``text_projection``,
        taken at each sequence's EOT token.
    """
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
    device = x.device
    # Frozen positions (mask1) + trainable residual positions (mask2).
    pos_fixed = (self.positional_embedding.to(device) * self.mask1.to(device)).type(self.dtype)
    pos_res = (self.positional_embedding_res.to(device) * self.mask2.to(device)).type(self.dtype)
    x = x + pos_fixed + pos_res
    # Transformer consumes [seq, batch, dim]; convert back afterwards.
    x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
    x = self.ln_final(x).type(self.dtype)
    # The EOT token carries the highest token id in each sequence, so argmax
    # locates it; project its embedding into the joint space.
    eot_positions = text.argmax(dim=-1)
    return x[torch.arange(x.shape[0]), eot_positions] @ self.text_projection
def forward(self, image, text_long, text_short, rank):
    """Compute LongCLIP's two contrastive losses for one distributed rank.

    Args:
        image: batch of images owned by this rank.
        text_long: tokenized long captions paired with ``image``.
        text_short: tokenized short captions paired with ``image``.
        rank: this process's distributed rank; offsets the in-batch targets
            into the globally gathered similarity matrices.

    Returns:
        Tuple ``(loss_itcl, loss_itcs)``: image-text contrastive losses for
        the long-text and short-text branches.
    """
    img_long = self.encode_image(image)
    txt_long = self.encode_text(text_long)
    txt_short = self.encode_text(text_short)

    # L2-normalize all feature sets.
    img_long = img_long / img_long.norm(dim=1, keepdim=True)
    txt_long = txt_long / txt_long.norm(dim=1, keepdim=True)
    txt_short = txt_short / txt_short.norm(dim=1, keepdim=True)

    # Coarse image features: rank-32 PCA reconstruction of the fine features,
    # matched against the short captions (PCM branch).
    img_short = self.PCA(img_long, 32)

    # Gradient-preserving all_gather across ranks.
    gather = torch.distributed.nn.all_gather
    img_long_all = torch.cat(gather(img_long), dim=0)
    img_short_all = torch.cat(gather(img_short), dim=0)
    txt_long_all = torch.cat(gather(txt_long), dim=0)
    txt_short_all = torch.cat(gather(txt_short), dim=0)

    # Temperature-scaled similarities: local rows against global columns.
    scale = self.logit_scale.exp()
    sim_i2tl = scale * (img_long @ txt_long_all.T)
    sim_tl2i = scale * (img_long_all @ txt_long.T).T
    sim_i2ts = scale * (img_short @ txt_short_all.T)
    sim_ts2i = scale * (img_short_all @ txt_short.T).T

    # Ground-truth pairs sit on this rank's diagonal block of the global matrix.
    bs = image.size(0)
    targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=torch.long).to(image.device)

    loss_itcl = (
        F.cross_entropy(sim_i2tl, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_tl2i, targets, label_smoothing=0.1)
    ) / 2
    loss_itcs = (
        F.cross_entropy(sim_i2ts, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_ts2i, targets, label_smoothing=0.1)
    ) / 2
    return loss_itcl, loss_itcs
- 说明
1) mask1 -> 只对前20个(保留的)pos_emb可见; mask2 -> 负责后(248-20)=228个(可训练的)pos_emb
2)对细粒度视觉特征进行PCA,保证提取细粒度对齐的同时,也要学习主要概念
# PCA via SVD rather than an eigendecomposition, which could blow up to inf.
def PCA(self, input_tensor, PCA_dim):
    """Return the rank-``PCA_dim`` reconstruction of ``input_tensor``.

    Rows are samples. Features are centered, projected onto the top
    ``PCA_dim`` principal components (right singular vectors), mapped back
    to the original space, and the mean is restored. Computation runs in
    float32.
    """
    # Center each feature dimension around its mean.
    feature_mean = torch.mean(input_tensor, dim=0)
    centered = (input_tensor - feature_mean.unsqueeze(0)).float()
    # SVD is numerically stabler than eig for extracting principal components.
    _, _, v_t = torch.linalg.svd(centered, full_matrices=False)
    components = v_t.T[:, :PCA_dim]
    # Project into the principal subspace, then back to the original space.
    projected = centered @ components
    reconstructed = projected @ components.T
    return reconstructed + feature_mean