针对中文CLIP模型训练失效的解决方案记录

针对CLIP模型训练失效的解决方案记录

前言

最近在训练中文CLIP模型的时候,遇到了不少困扰,好在经过不断的尝试和对源代码的解读,最终还是得到了正确的训练结果,特此分享给大家,少走弯路。

遇到的问题

我的目标:通过CLIP模型训练一个四分类模型

预期流程:用户输入中文描述词,输入到CLIP模型中,返回类别

数据集情况:四分类数据集,包含四条描述词(即四个类别)和将近六千幅图像

训练参数:默认参数

难题:训练集的Loss持续振荡,准确率没有提升;验证集Loss和准确率无明显变化

问题分析

1. 数据集层面

a. 分析

查看官方样例中的Muge数据集样本分布情况,发现与我的数据集分布差异非常大,Muge数据集基本上一个描述词会对应一张图片,而我是一个描述词对应上千张图片。

b. 尝试

参考Muge数据集,遍历每一张图片,为此按照 类别+辅助词 的形式构成每幅图片的描述词,辅助词类似于好用、坚固、可靠等构成的列表,每张图片的描述词随机,目标是增加描述词的数量。

c. 结果

训练集:loss呈下降趋势,准确率在第10个Epoch明显上升

验证集:Loss不断升高,准确率无明显变化

出现了典型的过拟合现象,此路不通。

2. 算法层面

a. 分析

我尝试用默认参数运行了Muge数据集,发现参数正常,模型能够收敛,证明模型没有问题;同时尝试修改了很多默认参数,例如学习率、正则化项等,皆没有显著影响。

P.S:一般模型训练出现问题,99%的问题关键点都出现在数据或者数据读取上。

因此翻看模型的数据读取、预处理、训练、评估等步骤的源代码,果然发现了问题所在。

with torch.no_grad():
    for i in range(dataloader.num_batches):
        batch = next(data_iter)
        images, texts, eos_indices = batch

        images = images.cuda(args.local_device_rank, non_blocking=True)
        texts = texts.cuda(args.local_device_rank, non_blocking=True)
        eos_indices = eos_indices.cuda(args.local_device_rank, non_blocking=True)

        image_features, text_features, logit_scale = model(images, texts)
        all_image_features.append(image_features)
        all_text_features.append(text_features)
        logit_scale = logit_scale.mean()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        ground_truth = torch.arange(len(images)).long()
        ground_truth = ground_truth.cuda(args.local_device_rank, non_blocking=True)
        total_loss = (
            loss_img(logits_per_image, ground_truth)
            + loss_txt(logits_per_text, ground_truth)
        ) / 2

        batch_size = len(images)
        cumulative_loss += total_loss * batch_size
        num_elements += batch_size
        
        cumulative_i2t_acc += ((logits_per_image.argmax(-1) == ground_truth).sum()).float()
        cumulative_t2i_acc += (logits_per_text.argmax(-1) == ground_truth).sum().float()

这段代码出现在evaluate函数中,用于计算损失值,逐行查看

在这里插入图片描述

针对image2txt任务,对于一个正确训练的clip模型,矩阵应该如下,很容易看出,对角线上有着更高的置信度。我们的GroundTruth由ground_truth = torch.arange(len)生成,即对角线数值的索引。交叉熵会按照红框的形状依次进行计算。

但由于我们的数据一个描述词对应多个图片,就会出现如下情况:
在这里插入图片描述

不难发现,一个batch中出现了两条苹果词条,这两个都是正确的,但我们的groundtruth不知道发生了什么,它认为只有对角线上的索引是正确的。因此loss无法正常计算,模型学习不到东西,或者说,明明学会了分类,但老师却告诉他分的不对,给模型也整不会了。

因此这个问题就跟batch-size有关,理论上batch-size设置的越大,学习能力也就越差。

但仅仅调小batch-size也是不行的,针对我们的四分类任务,batch-size设置为4,仍然会有非常大的可能包含重复的词条,如果设置为1的话,不论分类结果是否正确,都会被视作是正确的,因此batch-size万万不可设置为1。

b. 尝试

既然弄清了问题所在,那么就要强制模型每个batch里面必须包含不重样的图文对,这样就可以简化为分类任务,简单来说,可以这样实现:

├─类别A
  ---1.jpg
  ---2.jpg
  ...
├─类别B
  ---1.jpg
  ---2.jpg
  ...
├─类别C
  ---1.jpg
  ---2.jpg
  ...
└─类别D
  ---1.jpg
  ---2.jpg
  ...

batch-size固定为4,第一个batch,读取类别A的1.jpg,类别B的1.jpg,类别C的1.jpg,类别D的1.jpg;第二个batch,读取类别A的2.jpg,类别B的2.jpg,类别C的2.jpg,类别D的2.jpg…以此类推

c. 结果

训练集:Loss持续下降,且很快收敛;精度很快上升,且很快保持在100

验证集:同训练集

训练成功,至此我们就解决了CLIP模型的训练问题。

总结

当个炼丹师也不容易,如果可以的话,尽可能的将自己的数据集和官方demo使用的数据集的规模、结构、组织形式等全方位保持一致,这样会少走不少弯路;同时遇到训练精度问题,还是多从数据上找原因,对于已发表文章的模型来说,数据决定了下限,参数再怎么改也救不了差数据。

当个炼丹师也不容易,如果可以的话,尽可能的将自己的数据集和官方demo使用的数据集的规模、结构、组织形式等全方位保持一致,这样会少走不少弯路;同时遇到训练精度问题,还是多从数据上找原因,对于已发表文章的模型来说,数据决定了下限,参数再怎么改也救不了差数据。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值