提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
在用CLIP作为encoder的时候,冻结CLIP的权重使其不参与梯度更新,后续的融合网络正常训练
一、在model里将网络搭建好
import torch
import clip
from PIL import Image
from DEQfusion import DEQFusion
import torch.nn as nn
import torch
import torch.nn.functional as F
class model_enc(nn.Module):
    """Encode one image/text pair with CLIP and fuse the two features
    with DEQFusion.

    ``forward`` takes a path to an image file and a raw text string;
    PIL loading, CLIP preprocessing and tokenization all happen inside
    ``forward``, so callers pass plain Python values.
    """

    def __init__(self):
        super(model_enc, self).__init__()
        # Resolve the device once and reuse it everywhere (the original
        # evaluated the cuda-availability check twice, which invites the
        # two results drifting apart if one line is later edited).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.num_modals = 2      # two modalities: image + text
        self.channel_dim = 512   # ViT-B/32 embedding width
        self.fusion = DEQFusion(self.channel_dim, self.num_modals)

    def forward(self, image, text):
        """Return ``(fused_feat, jacobian_loss, trace)`` from DEQFusion.

        image: path to an image file readable by PIL.
        text:  a plain string, tokenized via ``clip.tokenize``.
        """
        image = self.preprocess(Image.open(image)).unsqueeze(0).to(self.device)
        text = clip.tokenize([text]).to(self.device)
        # Order matters downstream: DEQFusion receives [image, text].
        features = [
            self.model.encode_image(image),
            self.model.encode_text(text),
        ]
        fused_feat, jacobian_loss, trace = self.fusion(features)
        return fused_feat, jacobian_loss, trace
if __name__ == "__main__":
    # Choose the device dynamically instead of hard-coding "cuda":
    # the original `.to("cuda")` crashed on CPU-only machines even
    # though model_enc itself already handles both devices.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    enc = model_enc().to(device)
    image = "CLIP.png"          # path to a demo image on disk
    text = "C opens a bag"      # demo caption
    fused_feat, jacobian_loss, trace = enc(image, text)
    print(fused_feat.shape)
这里将image和text分别通过CLIP的encoder提取特征,再用一个DEQFusion模块将两路特征融合
二、冻结CLIP
class model_enc(nn.Module):
    """CLIP-frozen variant: CLIP weights are loaded but excluded from
    training; only the DEQFusion head keeps trainable parameters."""

    def __init__(self):
        super(model_enc, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda" if torch.cuda.is_available() else "cpu")
        # At this point the only submodule registered on `self` is the
        # CLIP model, so this loop freezes exactly the CLIP weights.
        # DEQFusion is constructed below, AFTER the loop, so its
        # parameters keep requires_grad=True and train normally.
        for p in self.parameters():
            p.requires_grad = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_modals = 2      # two modalities: image + text
        self.channel_dim = 512   # ViT-B/32 embedding width
        self.fusion = DEQFusion(self.channel_dim, self.num_modals)