BLIP学习笔记

参考网址: https://zhuanlan.zhihu.com/p/628375255?utm_id=0
https://blog.csdn.net/qq_40905284/article/details/130686913
https://zhuanlan.zhihu.com/p/623257270

BLIP
学习目的:了解模型架构
特点:冻结参数的预训练视觉模型和大型语言模型image encoder LLM
Querying Transformer 从冻结的视觉编码器中引导多模态学习 从冻结的文本编码器中引导多模态学习 从而减少训练参数
在这里插入图片描述
Q-Former使用一组可学习的 Query 向量从冻结的视觉编码器中提取视觉特征,并充当视觉编码器和文本编码器之间的桥梁。Q-Former 把关键的视觉信息传递给 LLM。
第一个预训练阶段,Q-Former 学习与文本最相关的视觉表征。
第二个预训练阶段,通过将 Q-Former 的输出连接到冻结的 LLM 来执行视觉语言生成学习,使其输出的视觉表征可以直接由 LLM 解释。这样一来,Q-Former 就可以有效地利用冻结的预训练图像模型和语言模型。
个人理解:给定一个文本 让Q匹配到最相关的图片 传出LLM可以识别的信息 相当于一个桥梁 图像编码器提取视觉特征
在这里插入图片描述
Self-Attention 的输入有2个:Queries 和 Text。
第1个 Transformer 子模块:是 Image Transformer,它与图像编码器交互,用于视觉特征提取。输入可学习的 Queries,它们先通过 Self-Attention 建模互相之间的依赖关系,再通过 Cross-Attention 建模 Queries 和图片特征的依赖关系。因为两个 Transformer 的子模块是共享参数的,所以 Queries 也可以与文本输入做交互(feed forward 前馈神经网络)
第2个 Transformer 子模块:是 Text Transformer,它既可以作为文本编码器,也可以充当文本解码器。
Q-Former 连接到冻结参数的图像编码器,并使用图像-文本对进行预训练,那么这一步的目标是训练好 Q-Former,以便 Queries 可以学习到如何更好地结合文本提取图片信息。
ITC对齐图像和文本的表征,互信息最大化
计算 Queries 的输出Z和 Text Transformer 的 [CLS] token 输出 T的对比学习损失。因为Z 包含了多个 Queries 的输出,因此作者首先计算每个 Queries 的输出和T之间的成对相似度,然后选择最高的一个作为最终的图文相似度。不允许相互看到
IGT给定一张图片 训练生成对应的文本描述 提取捕获了所有文本信息的视觉特征
计算Q隐藏T 根据Q输出T T根据当前预测以后的允许 Text 看到 Queries (Queries 里面有视觉信息),同时每个 Text token 只能看到它之前的 Text token (生成式任务的基本做法)。但是不允许 Queries 看到 Text 的信息,只能看到自己的信息。此外作者还将 [CLS] token 替换为一个新的 [DEC] token 作为第一个 Text token 来指示解码任务。
ITM二分类
允许 Text 看到 Queries,同时每个 Text token 只能看到它之前的 Text token 。但是不允许 Queries 看到 Text 的信息,只能看到自己的信息。此外作者还将 [CLS] token 替换为一个新的 [DEC] token 作为第一个 Text token 来指示解码任务。
在这里插入图片描述

利用 LLM 的文本生成能力,图像的表征和 Queries 一起送入 Q-Former,得到 Queries 的输出z,经过一个全连接层与 Text token 的维度对齐之后输入给 LLM Decoder。这个 Queries 的输出就蕴含了视觉信息,在输入给 LLM 的时候就充当了 Soft Visual Prompt 的作用。
Queries 在经过了第1阶段的训练之后,已经学习到了如何更好地结合文本提取图片信息,因此它可以有效地将最有用的图片信息提供给 LLM,同时删除不相关的视觉信息,这减少了 LLM 学习视觉语言对齐的负担。

  • create dataset、loader
    # 创建检索数据集
    print("Creating retrieval dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)  

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()            
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

    # 创建数据加载器
    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
                                                        batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
                                                        num_workers=[4,4,4],
                                                        is_trains=[True, False, False], 
                                                        collate_fns=[None,None,None])
  • model
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 
                             vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 
                             queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])

    model = model.to(device)   
    
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module   

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 
    
    best = 0
    best_epoch = 0
  • training 分数
    for epoch in range(0, config['max_epoch']):    
        if not args.evaluate:        
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
                
            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
            
            train_stats = train(model, train_loader, optimizer, epoch, device, config)  
            
        score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
        score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
    
        if utils.is_main_process():  
      
            val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)  
            print(val_result)
                                
            if val_result['r_mean']>best:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))  
                best = val_result['r_mean']        
                best_epoch = epoch  
                
                test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) 
                print(test_result)
  • 遍历数据集 将数据移动到设备上
    for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # 将数据移动到设备上
        image = image.to(device,non_blocking=True)   
        idx = idx.to(device,non_blocking=True)   
       
        # 计算alpha
        if epoch>0:
            alpha = config['alpha']
        else:
            alpha = config['alpha']*min(1,i/len(data_loader))


        loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)                  
        loss = loss_ita + loss_itm
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        
        metric_logger.update(loss_itm=loss_itm.item())
        metric_logger.update(loss_ita=loss_ita.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

大模型在大数据量下 大模型训练耗时所以冻结住 文本和图像编码器冻结住不更新 但是不更新的话 gap太大了 所以引入小的transformer 作为桥梁 把图文gap连接 Q-Former参数少 第一阶段表征学习阶段 生成学习阶段 输入可学习的queries的一个embedding 和文本 作为查询器 和当前文本最相关的图像信息是什么 抹除不相关信息 提取视觉信息(与目标文本相关的)输入大模型LLM 下游任务
表征学习 输入图像 文本 冻结的图像编码器会得到图像的embedding
IGT 计算Q隐藏T 根据Q输出T T根据当前预测以后的
生成 视觉信息经过FC

模型 基于encoder的模型很难做生成任务 没有检索任务encoder-decoder-based
MED生成式语言任务 第一个 cls embedding做图文对比学习
encode decode生成任务因果自注意力 LM从bert的形式换成gpt的形式
人工标注的微调

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值