针对CLIP模型训练失效的解决方案记录
前言
最近在训练中文CLIP模型的时候,遇到了不少困扰,好在经过不断的尝试和对源代码的解读,最终还是得到了正确的训练结果,特此分享给大家,少走弯路。
遇到的问题
我的目标:通过CLIP模型训练一个四分类模型
预期流程:用户输入中文描述词,输入到CLIP模型中,返回类别
数据集情况:四分类数据集,包含四条描述词(即四个类别)和将近六千幅图像
训练参数:默认参数
难题:训练集的Loss持续振荡,准确率没有提升;验证集Loss和准确率无明显变化
问题分析
1. 数据集层面
a. 分析
查看官方样例中的Muge数据集样本分布情况,发现与我的数据集分布差异非常大,Muge数据集基本上一个描述词会对应一张图片,而我是一个描述词对应上千张图片。
b. 尝试
参考Muge数据集,遍历每一张图片,为此按照 类别+辅助词
的形式构成每幅图片的描述词,辅助词
类似于好用、坚固、可靠
等构成的列表,每张图片的描述词随机,目标是增加描述词的数量。
c. 结果
训练集:loss呈下降趋势,准确率在第10个Epoch明显上升
验证集:Loss不断升高,准确率无明显变化
出现了典型的过拟合现象,此路不通。
2. 算法层面
a. 分析
我尝试用默认参数运行了Muge数据集,发现参数正常,模型能够收敛,证明模型没有问题;同时尝试修改了很多默认参数,例如学习率、正则化项等,皆没有显著影响。
P.S:一般模型训练出现问题,99%的问题关键点都出现在数据或者数据读取上。
因此翻看模型的数据读取、预处理、训练、评估等步骤的源代码,果然发现了问题所在。
with torch.no_grad():
for i in range(dataloader.num_batches):
batch = next(data_iter)
images, texts, eos_indices = batch
images = images.cuda(args.local_device_rank, non_blocking=True)
texts = texts.cuda(args.local_device_rank, non_blocking=True)
eos_indices = eos_indices.cuda(args.local_device_rank, non_blocking=True)
image_features, text_features, logit_scale = model(images, texts)
all_image_features.append(image_features)
all_text_features.append(text_features)
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
ground_truth = torch.arange(len(images)).long()
ground_truth = ground_truth.cuda(args.local_device_rank, non_blocking=True)
total_loss = (
loss_img(logits_per_image, ground_truth)
+ loss_txt(logits_per_text, ground_truth)
) / 2
batch_size = len(images)
cumulative_loss += total_loss * batch_size
num_elements += batch_size
cumulative_i2t_acc += ((logits_per_image.argmax(-1) == ground_truth).sum()).float()
cumulative_t2i_acc += (logits_per_text.argmax(-1) == ground_truth).sum().float()
这段代码出现在evaluate
函数中,用于计算损失值,逐行查看
针对image2txt任务,对于一个正确训练的clip模型,矩阵应该如下,很容易看出,对角线上有着更高的置信度。我们的GroundTruth由ground_truth = torch.arange(len)
生成,即对角线数值的索引。交叉熵会按照红框的形状依次进行计算。
但由于我们的数据一个描述词对应多个图片,就会出现如下情况:
不难发现,一个batch中出现了两条苹果词条,这两个都是正确的,但我们的groundtruth不知道发生了什么,它认为只有对角线上的索引是正确的。因此loss无法正常计算,模型学习不到东西,或者说,明明学会了分类,但老师却告诉他分的不对,给模型也整不会了。
因此这个问题就跟batch-size有关,理论上batch-size设置的越大,学习能力也就越差。
但仅仅调小batch-size也是不行的,针对我们的四分类任务,batch-size设置为4,仍然会有非常大的可能包含重复的词条,如果设置为1的话,不论分类结果是否正确,都会被视作是正确的,因此batch-size万万不可设置为1。
b. 尝试
既然弄清了问题所在,那么就要强制模型每个batch里面必须包含不重样的图文对,这样就可以简化为分类任务,简单来说,可以这样实现:
├─类别A
---1.jpg
---2.jpg
...
├─类别B
---1.jpg
---2.jpg
...
├─类别C
---1.jpg
---2.jpg
...
└─类别D
---1.jpg
---2.jpg
...
batch-size固定为4,第一个batch,读取类别A的1.jpg,类别B的1.jpg,类别C的1.jpg,类别D的1.jpg;第二个batch,读取类别A的2.jpg,类别B的2.jpg,类别C的2.jpg,类别D的2.jpg…以此类推
c. 结果
训练集:Loss持续下降,且很快收敛;精度很快上升,且很快保持在100
验证集:同训练集
训练成功,至此我们就解决了CLIP模型的训练问题。
总结
当个炼丹师也不容易,如果可以的话,尽可能的将自己的数据集和官方demo使用的数据集的规模、结构、组织形式等全方位保持一致,这样会少走不少弯路;同时遇到训练精度问题,还是多从数据上找原因,对于已发表文章的模型来说,数据决定了下限,参数再怎么改也救不了差数据。
当个炼丹师也不容易,如果可以的话,尽可能的将自己的数据集和官方demo使用的数据集的规模、结构、组织形式等全方位保持一致,这样会少走不少弯路;同时遇到训练精度问题,还是多从数据上找原因,对于已发表文章的模型来说,数据决定了下限,参数再怎么改也救不了差数据。