ALBEF(Align before Fuse: Vision and LanguageRepresentation Learning with Momentum Distillati) 算法阅读记录

本文详细描述了在ALBEF项目中进行环境配置、数据预处理、前向流程,包括ITC和MLM损失的计算方法,特别关注了BertForMaskedLM的使用和MoCo算法的应用。作者解决了tokenizer_class错误,并分享了文本和图片编码的处理过程。
摘要由CSDN通过智能技术生成

论文地址:👻

目录

一、环境配置

二、训练时的前向过程

数据预处理

前向流程

计算ITC损失

ITM损失的计算

MLM损失的计算


一、环境配置

至于数据集,我这有自己制作的小批量的测试版并与json文件配套的(一共加起来不到10M)。制作过程以及具体地环境配置以及执行时的报错问题可以参考这篇 😜

本次不做完整训练,依然只做测试和阅读。

python -m torch.distributed.launch --nproc_per_node=1 --use_env Pretrain.py --config /root/data/zjx/Code-subject/ALBEF/ALBEF-main/configs/Pretrain.yaml --output_dir /root/data/zjx/Code-subject/ALBEF/ALBEF-main/output/Pretrain

出现错误

TypeError: add_code_sample_docstrings() got an unexpected keyword argument 'tokenizer_class'

解决办法(来源于github官网)

采用最后一个人的建议

processor_class 替换 tokenizer_class

那个说hugging face 的去看了一下

确实发现有的 -减号 表示删去了 +加号表示新增了。其实就是替换了,后面的赋值还是一样。其实你一打 pro就自动补齐出来了,上图只是确定一下是不是作用没有变。

有好几处这样的错误,直接用IDE   执行Ctrl+F 搜索 下面去改

 @add_code_sample_docstrings

对于 text Encoder 的初始化,做了如下改动

​
self.text_encoder_m = BertForMaskedLM.from_pretrained('/root/data/zjx/Code-subject/BLIP/bert_base_uncased', config=bert_config, local_files_only = True)

​

具体地,之前一直报

ValueError: Connection error, and we cannot find the requested files in the cached path. Please try again or make sure your Internet connection is on.

这样地错误,所以,利用本地的文件夹时在后面加上local_files_only = True。

二、训练时的前向过程

数据预处理

对于caption首先会将其中的一些不正规字符替换掉。对于图像,会经过数据增强操作进行transform。最终返回 caption image 样本对儿。

对text 进行 量化

{'input_ids': tensor([[  101,  1037,  2485,  2298,  2012,  3467,  2962, 24019,  7657,  9587,
         26328,  2094,  4777,  1998,  2304,  2030,  2829,  4230,  6302,  2011,
          4463,  4241, 21179],
        [  101,  1037,  4799,  3242,  1999,  1996,  2542,  2282,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  1037,  2485,  2298,  2012,  3467,  2962, 24019,  7657,  9587,
         26328,  2094,  4777,  1998,  2304,  2030,  2829,  4230,  6302,  2011,
          4463,  4241, 21179],
        [  101,  1037,  4799,  3242,  1999,  1996,  2542,  2282,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')}

这之前还包括 文本queue 和 imgae queue的初始化

self.image_queue = nn.functional.normalize(self.image_queue, dim=0)  # (256,65536)
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)  # (256,65536)

前向流程

1、image 送入 Image encoder, 输出 image表征, (B,257,768) ,257是因为在ViT的前向传播过程中在第0维加入了 cls token

2、 将image representation 的 类别cls token embeded 送入线性层映射到固定维度 768-->256 输出 image representation (B,256),然后进行L2 归一化

3、 文本送入 text Encoder 生成text 表征,和图像一样,将cls token embeded 映射到256固定维度。 (B,256)然后进行L2归一化

4、 进行 MoCo 算法的动量更新,具体可以参考这里🐮

MoCo算法中设立了query 和 key,进而创建两个encoder 即 Encoder_q 和 Encoder_k。这里创建的 visual_encoder_m 和 text_encoder_m 和 MoCo中的Encoder_k一样,用来更新queue。

        -根据 MoCo算法来创建 负样本

5、

计算ITC损失

代码中的 image_embeds_m和 text_encoder_m 除了起到了上面 4 中所述的作用后,还有着另一层含义。即文中提及的 单模态动量蒸馏encoder

除了需要计算文中的公式1

之外,最终的ITC损失是由公式6确定的

为什么要这么做呢?文中指出

而且,根据文中的实验细节

\alpha在第一个epoch内从0到0.4线性递增。

那这就合理了。因为在训练开始时,模型还没有学到东西,所以,不必担心负样本的文本会匹配到图像的问题,随着训练的进行,倒是有可能发生的,不过占比很小。

代码中它是如何实现的呢?

1、根据MoCo,制作负样本

image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

注意这里的真是Image 和 text 的编码表示在前,queue 中的key在后。因此,其正样本匹配的标签应该为

sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)  # label (4,65540)
sim_targets.fill_diagonal_(1)

也就是只有主要对角线的 文本和 图像 是匹配的,因为它计算的公式1

sim_i2t_m = image_feat_m @ text_feat_all / self.temp   #  (4,65540)  负样本对的 softmax-normalized
sim_t2i_m = text_feat_m @ image_feat_all / self.temp   # (4,65540)

是矩阵乘法,所以 {v_{cls}}_1{w_{cls}}_1{v_{cls}}_2{w_{cls}}_2, ..., {v_{cls}}_b{w_{cls}}_b 才是Image-Text匹配的(b为batch size 大小),其余都是负样本。

根据文中提出的动量蒸馏,它 的pseudo targets为

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets  # (4,65540)
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets  # (4,65540)

注意这个是取了softmaxt的,所以这个伪目标是个概率分布 q,这个相当于伪标签。

根据公式6,还得求一下 p 的概率分布,也就是模型所预测的Image-text相似度的概率分布

        sim_i2t = image_feat @ text_feat_all / self.temp   # (4,65540)
        sim_t2i = text_feat @ image_feat_all / self.temp   # (4,65540)

(这个目前还不是概率分布,转成概率分布的操作在下一步一并完成)

注意这里的区别。这里的image feat 和 text feat 相当于MoCo算法中的query,来自最终需要训练得到的目标模型。

最后,根据公式6,计算KL散度,这里附上KL散度公式

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 

因为 伪目标这个概率分布 是已知的,所以,就跟交叉熵的公式推导一样,去掉常数项,其大小只与

-P(x)log^{Q(x))}

有关,公式6以及代码中谁是P(x),谁是Q(x)一代入就一目了然了。

6、执行MoCo算法中的 更新 queue 操作 

7、

ITM损失的计算

这部分要利用到multimodal encoder来完成。

Image encoder 和 text encoder的输出 送入其中。 

我看了一下,这部分与我之前的帖子中的操作一样,具体参考这里🐱‍👤,这里不再赘述。

8、

MLM损失的计算

(1)、对文本句子进行预处理。 

如文中所述,这里有两个主要的点。

第一:标签,不是完成的句子,而是只计算 被掩码改掉的部分。也就是 标签 target 只在掩码mask处保留了 输入句子的单词信息,其余都不要,都掩盖掉。(还有cls 和 pad 的地方没有掩盖掉,也就行了保留)

第二:输入的处理。输入首先 (a)、要去掉盖住的单词部分(将标签处的token id 换为 mask token id 即可),其目的是告诉模型,这里是需要预测的单词所在的位置。 接下来,在  (a)步 操作后的基础之上,进行进一步处理。具体地,掩码的MasK部分进行操作,将掩码 mask id 随机替换为其它单词。也就是说,输入中的部分 掩码部分 被  一个随机的单词 重新补齐了

总结:正常来说,标签和输入是互补的,即input中去除了mask 部分的单词(单词 token id 转为 mask token id),标签只保留mask部分,其余去除(mask位置保留原单词 token id, 其余填充 -100)。但是文中提出了不正常的操作,即input中的随机一部分mask 被替换为 随机的单词。

最终的input 和 target 举例

# target
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7657,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100, 21179],
        [ -100,  -100,  -100,  -100,  -100,  1996,  2542,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  4777,  -100,  -100,  2030,  2829,  -100,  -100,  2011,
          -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  2282,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100]], device='cuda:0')
# input  其中 103 为 mask token id
tensor([[  101,  1037,  2485,  2298,  2012,  3467,  2962, 24019,   103,  9587,
         26328,  2094,  4777,  1998,  2304,  2030,  2829,  4230,  6302,  2011,
          4463,  4241,   103],
        [  101,  1037,  4799,  3242,  1999,   103,   103,  2282,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  1037,  2485,  2298,  2012,  3467,  2962, 24019,  7657,  9587,
         26328,  2094, 16216,  1998,  2304,   103,   103,  4230,  6302, 23970,
          4463,  4241, 21179],
        [  101,  1037,  4799,  3242,  1999,  1996,  2542, 20891,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]], device='cuda:0')

呀,句子太短,看不出来不正常的操作,哈哈(●ˇ∀ˇ●)。可能不正常的操作也是小概率事件😁

( 2)计算损失

交叉熵损失。注意这里

 y是词汇表的onehot。差不多和求最大似然一样。

=========================================================================

=========================================================================

连续几天看了几篇论文和代码,哈哈,对新领域有了很深的了解,哈哈,

~你这哪够十五斤啊,你这秤有问题啊 吸铁石☝~

~要是不熟我自己吃了它,满意了吧 😂~

=========================================================================

🎵~ Old but I'm not that old. Young but I'm not that bold. And I don't think the world is sold on just doing what we're told.~🎵 😛😎😝

  • 19
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

匿名的魔术师

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值