一、背景
在大模型应用场景中,多模态学习是一个重要的研究方向,它涉及到将不同模态的信息(如文本、图像、音频等)进行有效整合,以提高模型的理解和表达能力。图文匹配识别作为多模态学习中的一个关键问题,要求模型能够理解图像内容并将其与相应的文本描述进行匹配。
二、任务要求、数据说明
给了许多的图片、文本,要求参赛者进行两两匹配
训练集中图片、文本有给对应的匹配关系,如下
测试集就没有了,只给了训练集的text、image两部分信息,需要选手自己把这些算出来
三、思路
先把训练集用你选用的模型给finetune一遍最好,要是嫌麻烦不训练也行,因为现在的很多模型在这方面已经做得很好了。然后调clip模型,因为这个模型在图文上都支持的非常好,而且在同一空间下,计算万emb后,可以直接用余弦相似度来进行一一匹配
四、代码
1、加载数据
path = "/mnt/workspace/dataset/test_candidate_image.csv"
test_candidate_image = pd.read_csv(path)
print(test_candidate_image.head())
print(len(test_candidate_image))
images = list(test_candidate_image['image_name'])
print(len(images))
print(images[:10])
path = "/mnt/workspace/dataset/test_candidate_text.csv"
test_candidate_text = pd.read_csv(path)
print(test_candidate_text.head())
print(len(test_candidate_text))
titles = list(test_candidate_text['text'])
print(len(titles), titles[:10])
title_index = list(test_candidate_text['idx'])
print(len(title_index), title_index[:10])
2、计算两两相似度
import sys
import glob
from PIL import Image
sys.path.append('..')
from similarities import ImageHashSimilarity, SiftSimilarity, ClipSimilarity
m = ClipSimilarity(model_name_or_path="/mnt/workspace/OFA-Sys/chinese-clip-vit-base-patch16")
print("m=",m)
print("=======================imgaes begin=======================")
image_fps = ['/mnt/workspace/dataset/images/'+i for i in images]
print(len(image_fps), image_fps[:4])
imgs = [Image.open(i) for i in image_fps]
print(len(imgs), imgs[:4])
print("=======================imgaes end=======================")
print("=======================titles end=======================")
print(len(titles), titles[:4])
print("=======================titles end=======================")
sim_scores = m.similarity(titles, imgs)
3、计算每个title最匹配的img
#### 下面这个是title作为第一个参数传进去
res = []
for i in range(len(title_index)):
title_index_one = title_index[i]
max_index = max(enumerate(sim_scores[i]), key=lambda x: x[1])[0]
image_path_one = image_fps[max_index].split("/")[-1]
# print('title_index_one=',title_index_one, "image_path_one=",image_path_one, 'title=',titles[i])
res.append([title_index_one,image_path_one])
if(i % 1000 == 0):
# if(i == 10):
print('i=',i)
# break
print("len(res)=",len(res), ' res[:4]=', res[:4])
res_df = pd.DataFrame(res, columns=['idx','image_name'])
res_df.to_csv('./res_df_title_ok_english.csv', index=False, encoding='utf-8')
** 更多在公众号:大模型软硬件