参考网址: https://zhuanlan.zhihu.com/p/628375255?utm_id=0
https://blog.csdn.net/qq_40905284/article/details/130686913
https://zhuanlan.zhihu.com/p/623257270
BLIP
学习目的:了解模型架构
特点:冻结参数的预训练视觉模型和大型语言模型image encoder LLM
Querying Transformer 从冻结的视觉编码器中引导多模态学习 从冻结的文本编码器中引导多模态学习 从而减少训练参数
Q-Former使用一组可学习的 Query 向量从冻结的视觉编码器中提取视觉特征,并充当视觉编码器和文本编码器之间的桥梁。Q-Former 把关键的视觉信息传递给 LLM。
第一个预训练阶段,Q-Former 学习与文本最相关的视觉表征。
第二个预训练阶段,通过将 Q-Former 的输出连接到冻结的 LLM 来执行视觉语言生成学习,使其输出的视觉表征可以直接由 LLM 解释。这样一来,Q-Former 就可以有效地利用冻结的预训练图像模型和语言模型。
个人理解:给定一个文本 让Q匹配到最相关的图片 传出LLM可以识别的信息 相当于一个桥梁 图像编码器提取视觉特征
Self-Attention 的输入有2个:Queries 和 Text。
第1个 Transformer 子模块:是 Image Transformer,它与图像编码器交互,用于视觉特征提取。输入可学习的 Queries,它们先通过 Self-Attention 建模互相之间的依赖关系,再通过 Cross-Attention 建模 Queries 和图片特征的依赖关系。因为两个 Transformer 的子模块是共享参数的,所以 Queries 也可以与文本输入做交互(feed forward 前馈神经网络)
第2个 Transformer 子模块:是 Text Transformer,它既可以作为文本编码器,也可以充当文本解码器。
Q-Former 连接到冻结参数的图像编码器,并使用图像-文本对进行预训练,那么这一步的目标是训练好 Q-Former,以便 Queries 可以学习到如何更好地结合文本提取图片信息。
ITC对齐图像和文本的表征,互信息最大化
计算 Queries 的输出Z和 Text Transformer 的 [CLS] token 输出 T的对比学习损失。因为Z 包含了多个 Queries 的输出,因此作者首先计算每个 Queries 的输出和T之间的成对相似度,然后选择最高的一个作为最终的图文相似度。不允许相互看到
IGT给定一张图片 训练生成对应的文本描述 提取捕获了所有文本信息的视觉特征
计算Q隐藏T 根据Q输出T T根据当前预测以后的允许 Text 看到 Queries (Queries 里面有视觉信息),同时每个 Text token 只能看到它之前的 Text token (生成式任务的基本做法)。但是不允许 Queries 看到 Text 的信息,只能看到自己的信息。此外作者还将 [CLS] token 替换为一个新的 [DEC] token 作为第一个 Text token 来指示解码任务。
ITM二分类
允许 Text 看到 Queries,同时每个 Text token 只能看到它之前的 Text token 。但是不允许 Queries 看到 Text 的信息,只能看到自己的信息。此外作者还将 [CLS] token 替换为一个新的 [DEC] token 作为第一个 Text token 来指示解码任务。
利用 LLM 的文本生成能力,图像的表征和 Queries 一起送入 Q-Former,得到 Queries 的输出z,经过一个全连接层与 Text token 的维度对齐之后输入给 LLM Decoder。这个 Queries 的输出就蕴含了视觉信息,在输入给 LLM 的时候就充当了 Soft Visual Prompt 的作用。
Queries 在经过了第1阶段的训练之后,已经学习到了如何更好地结合文本提取图片信息,因此它可以有效地将最有用的图片信息提供给 LLM,同时删除不相关的视觉信息,这减少了 LLM 学习视觉语言对齐的负担。
- create dataset、loader
# 创建检索数据集
print("Creating retrieval dataset")
train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
if args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
else:
samplers = [None, None, None]
# 创建数据加载器
train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
num_workers=[4,4,4],
is_trains=[True, False, False],
collate_fns=[None,None,None])
- model
model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
model = model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
best = 0
best_epoch = 0
- training 分数
for epoch in range(0, config['max_epoch']):
if not args.evaluate:
if args.distributed:
train_loader.sampler.set_epoch(epoch)
cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
train_stats = train(model, train_loader, optimizer, epoch, device, config)
score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
if utils.is_main_process():
val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
print(val_result)
if val_result['r_mean']>best:
save_obj = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'config': config,
'epoch': epoch,
}
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
best = val_result['r_mean']
best_epoch = epoch
test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
print(test_result)
- 遍历数据集 将数据移动到设备上
for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
# 将数据移动到设备上
image = image.to(device,non_blocking=True)
idx = idx.to(device,non_blocking=True)
# 计算alpha
if epoch>0:
alpha = config['alpha']
else:
alpha = config['alpha']*min(1,i/len(data_loader))
loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
loss = loss_ita + loss_itm
optimizer.zero_grad()
loss.backward()
optimizer.step()
metric_logger.update(loss_itm=loss_itm.item())
metric_logger.update(loss_ita=loss_ita.item())
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
大模型在大数据量下 大模型训练耗时所以冻结住 文本和图像编码器冻结住不更新 但是不更新的话 gap太大了 所以引入小的transformer 作为桥梁 把图文gap连接 Q-Former参数少 第一阶段表征学习阶段 生成学习阶段 输入可学习的queries的一个embedding 和文本 作为查询器 和当前文本最相关的图像信息是什么 抹除不相关信息 提取视觉信息(与目标文本相关的)输入大模型LLM 下游任务
表征学习 输入图像 文本 冻结的图像编码器会得到图像的embedding
IGT 计算Q隐藏T 根据Q输出T T根据当前预测以后的
生成 视觉信息经过FC
模型 基于encoder的模型很难做生成任务 没有检索任务encoder-decoder-based
MED生成式语言任务 第一个 cls embedding做图文对比学习
encode decode生成任务因果自注意力 LM从bert的形式换成gpt的形式
人工标注的微调