Co-Attack:针对VLP模型的对抗样本攻击

论文:Towards Adversarial Attack on Vision-Language Pre-training Models

代码: https://github.com/adversarial-for-goodness/Co-Attack

1. 攻击VLP模型的思路和本文贡献

1.1 改进攻击思路

为了攻击VLP模型的嵌入表示,图像和文本的对抗性扰动的对抗扰动应该协同考虑,而不是独立考虑。下图展示了针对视觉蕴涵任务对 ALBEF 进行对抗攻击的样例。 结果表明,只有扰动图像才能成功地将预测从“蕴含”变为“矛盾”。 然而,通过独立扰动图像和文本而不考虑它们的交互作用,攻击会失败,因为两种单模式攻击可能会相互冲突,并导致抵消 1 + 1 < 1 效应,所以为了攻击VLP模型的嵌入表征embedding,图像和文本的对抗性扰动应协同考虑,而不是单独考虑

1.2 本文贡献:

(1)分析了两种典型VLP模型架构(Fused VLP Model和Aligned VLP Model)以及三个下游V+L任务(文本-图像检索)的对抗性攻击性能

(2)针对VLP模型,提出了一种新的多模态对抗性攻击方法。通过考虑不同模态攻击之间的一致性,它协同组合多模态扰动以实现更强的对抗性攻击.也就是提出co-attack,其对抗损失的设计协同整合了图像和文本的对抗性扰动

注:本博客仅关注于对Aligned VLP Model如Clip模型的攻击

2. 论文前置知识

2.1 BERT-Attack

2.1.1 基于单词替换的文本对抗样本生成

在计算机视觉类的任务中,为了提高模型的泛化能力,我们一般会通过诸如图像反转,图像缩放等数据扩充策略来提高模型的泛化能力。但是数据扩充在自然语言处理的领域中却是复杂且很难实用的,常见的文本扩充策略例如近义词替换,embedding相近词替换,句子shuffle等策略并不能保证扩充策略的完全正确。其中一个重要的原因是文本数据是一个信息密集的信息载体,一个轻微的扰动就有可能完全破坏原句字要表达的内容。基于这个现状,文本对抗(Text Adversarial)成为了NLP领域一个非常重要的方向。文本对抗指的是基于现有模型和规则,通过生成容易导致模型误识的样本,来对模型的参数进行优化,从而提高现有文本模型的鲁棒性

2.1.2 TextFooler攻击

TextFooler就是其中一种基于单词替换策略的文本对抗样本生成方法,基本思路如下

(1)输入一段句子,找出其中最重要的单词Top k,然后针对这K个单词进行替换,替换后的句子需要满足以下前提:

  • 替换单词后,句子语义基本保持不变;
  • 被替换的句子在语法,流畅性上和上下文保持匹配;
  • 模型会在这个新句子上产生错误的预测结果。

(2)进行近义词提取,TextFooler的策略是使用Mrkšić等人的方法计算两个近义词的相似度,然后通过两个单词的cosine距离为每个关键词提取Top-N个同义词,并构建一个候选集,表示为: 。TextFooler选取了相似度大于 的Top-N个词来作为近义词,在它的实验中 N=50 ,=0.7

(3)将替换的词填入近义词,进行词性检查,语义相似度检验,规则过滤等生成文本对抗样本

基于笔记:文本对抗之TextFooler - 知乎 (zhihu.com)

2.2.3 Bert Attack 

针对TextFooler的改进思路

BERT-Attack[1]也是一个和TextFooler类似的词替换模型,他们的第一步都是先查找句子中的关键词。不同的是在生成阶段,BERT-Attack是借助了BERT[3]的掩码语言模型(Masked Language Model,MLM)的用于预测被Mask掉的单词的天然特性,将其应用到了近义词生成部分。对比TextFooler,基于MLM的BERT-Attack的效率要高很多,而且在语法和语义上也正确和连续,BERT-Attack的思想是使用一个BERT作为对抗生成器来生成对抗样本,使用另外一个BERT作为被攻击的模型,目标设计提高被攻击BERT的鲁棒性,BERT-Attack有两个核心步骤(图1):

  1. 为目标模型找到易攻击的词c1...ck,这个易攻击的词往往是帮助模型做出判断的关键词;
  2. 对易攻击词进行扰动或者替换,利用MLM算法生成每个关键词的Top-K个扰动,不断对关键词的扰动进行尝试,直到攻击成功。

查找易攻击词

Bert替换策略

传统的近义词替换策略一般是使用人工设计的规则进行替换[2],这个方案的最大问题是忽略了被替代单词的上下文信息,因此在流畅性以及语义一致性上存在问题,而且人工设计的诸多计算规则是非常耗时的。

在BERT的掩码语言模型中,它会输出掩码位置在整个字典上的概率分布,通过在海量语料上对BERT的训练,而且基于MLM的上下文替换是考虑到单词的上下文信息了的,因此掩码语言模型往往在掩码处预测的内容在流畅性和语义性上都比较真实。而且这个策略基本不需要其它外部评估模型,样本的生成速度也非常快。不同于训练BERT时的对每个单词以一定的概率进行掩码,BERT-Attack输入到模型中的是没有经过任何掩码的句子,这样生成的特征向量比掩码后的特征向量更具有语义一致性。

Bert Attack伪代码

基于笔记:文本对抗之BERT-Attack - 知乎 (zhihu.com)

2.2 VLP模型和下流任务

在这项工作中,作者考虑了 CLIP、ALBEF和TCL进行评估。CLIP属于aligned VLP,ALBEF和TCL属于Fused VLP

图像文本检索包含两个子任务:图像到文本检索(TR)和文本到图像检索(IR)。对于ALBEF和TCL,无论是TR还是IR,首先计算所有图像文本对的e_i和e_t之间的特征相似度得分,以检索Top-N候选者,然后使用e_m计算出的图像文本匹配得分用于排名。 CLIP上的TR和IR任务执行得更直接。 排名结果仅基于e_i和e_m之间的相似度。

视觉蕴涵(Visual Entailment)是一项视觉推理任务,用于预测图像和文本之间的关系是蕴涵关系、中性关系还是矛盾关系。ALBEF 和 TCL 都将 VE 视为三向分类问题,并使用多模态编码器 [CLS] 标记表示的完整层来预测类别概率。

视觉定位(Visual Grounding)根据相应输入文本的描述来定位输入图像中的区域。ALBEF 扩展了 Grad-CAM,并使用导出的注意力图对检测到的提案进行排序。

3 Co-Attack 核心思想

3.1 攻击方法选择

针对图像使用PGD Attack,曲中Ei(xi`)是对抗图像的embedding,Ei(xi)是原图像的embedding,

Loss选择KL散度

针对文本使用Bert Attack,xt'代表对原文本xt tokens的替换

3.2 协作多模态攻击Co-Attack

尽管上述分析发现同时扰动文本和视觉模态比单独扰动一种模态更有效。 然而,如最开始所讨论的,独立攻击两种模式可能会导致 1 + 1 < 1 抵消效应。因此,作者通过开发一种协作多模态对抗攻击解决方案来解决这个问题,称为协作多模态对抗攻击(Co-Attack)。 这使得能够共同对图像模态和文本模态进行攻击。 联合攻击的目的是鼓励扰动的多模态嵌入远离原始的多模态嵌入,或者扰动的图像模态嵌入远离扰动的文本模态嵌入。由于 Co-Attack 可以适用于攻击多模态和单模态嵌入,因此它适用于融合 VLP 和对齐的 VLP 模型

本文着重针对的Clip攻击中,使得对抗图像xi`的embedding和原图像xi的embedding以及对抗文本xt'的embedding的KL散度越来越大,α是一个正整数,代码中设置为3

论文没有伪代码展示,基于笔记:[论文总结] Co-Attack: Towards Adversarial Attack on Vision-Language Pre-training Models - 知乎 (zhihu.com)

4. 针对Clip攻击的核心代码

4.1 外部处理

Clip部分代码的功能是:生成干扰文本和干扰图像,并用于VLP下游任务文本检索图像以及图像生成文本任务中评估对抗效果

def retrieval_eval(model, ref_model, data_loader, tokenizer, device, config):
    images_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    image_attacker = ImageAttacker(config['epsilon'] / 255., preprocess=images_normalize, bounding=(0, 1), cls=args.cls)
    text_attacker = BertAttack(ref_model, tokenizer, cls=args.cls)
    ....
    image_feats = torch.zeros(num_image, model.visual.output_dim)
    text_feats = torch.zeros(num_text, model.visual.output_dim)
    for images, texts, texts_ids in data_loader:
        images = images.to(device)
        if args.adv != 0:
            #adv=4,run_before_fusion函数的功能是生成一个batch size
            #大小的对抗图像images和对抗文本texts
            images, texts = multi_attacker.run_before_fusion(images, texts, adv=args.adv, num_iters=config['num_iters'], max_length=77,
                                                             alpha=args.alpha)
        images_ids = [data_loader.dataset.txt2img[i.item()] for i in texts_ids]
        with torch.no_grad():
            images = images_normalize(images)
            #获取对抗文本和对抗图像经过Clip的encoder之后的embedding信息
            output = model.inference(images, texts)
            #将对抗图像和对抗样本的embedding和下标对应
            image_feats[images_ids] = output['image_feat'].cpu().float().detach()
            text_feats[texts_ids] = output['text_feat'].cpu().float().detach()
    #矩阵乘法
    sims_matrix = image_feats @ text_feats.t()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))
    #返回以后面进行效果评估,sims_matrix作为图像到文本的分数 sims_matrix.t()作为文件到图像的分数
    return sims_matrix.cpu().numpy(), sims_matrix.t().cpu().numpy()

4.2 Co-Attack针对Clip模型攻击

    def run_before_fusion(self, images, text, adv, num_iters=10, k=10, max_length=30, alpha=3.0):
        if adv == 2 or adv == 3:
            images_adv = self.image_attacker.run_trades(self.net, images, num_iters)
        #Clip攻击
        elif adv == 4:
            device = images.device
            #初始化对抗图像image_attack
            image_attack = self.image_attacker.attack(images, num_iters)
            #KL散度作为损失函数
            criterion = torch.nn.KLDivLoss(reduction='batchmean')
            with torch.no_grad():
                #针对文本的bert attack
                text_adv = self.text_attacker.attack(self.net, text, k)
                text_input = self.tokenizer(text_adv, padding='max_length', truncation=True, max_length=max_length,
                                            return_tensors="pt").to(device)
                text_adv_output = self.net.inference_text(text_input)
                #默认cls为null,获取对抗文本的embedding
                if self.cls:
                    #如果设置了Cls,只取CLS对应的embedding信息
                    text_adv_embed = text_adv_output['text_embed'][:, 0, :].detach()
                else:
                    text_adv_embed = text_adv_output['text_embed'].flatten(1).detach()

            with torch.no_grad():
                #该行代码操作是F.normalize(self.encode_image(images), dim=-1
                image_output = self.net.inference_image(self.image_normalize(images))
                #默认cls为null,获取对抗图像的embedding
                if self.cls:
                    #如果设置了cls,只取对应位置的embedding信息
                    image_embed = image_output['image_embed'][:, 0, :].detach()
                else:
                    image_embed = image_output['image_embed'].flatten(1).detach()

            for i in range(num_iters):
                #继续执行PGD attack
                image_adv = next(image_attack)
                #该行代码的操作是:{'image_embed': F.normalize(self.encode_image(image_adv), dim=-1)}
                image_adv_output = self.net.inference_image(image_adv)
                #默认cls=null,获取对抗图像的embedding
                if self.cls:
                    image_adv_embed = image_adv_output['image_embed'][:, 0, :]
                else:
                    image_adv_embed = image_adv_output['image_embed'].flatten(1)
                #计算论文的第一条损失: KL_loss( E(xi′) , E(xi) ), 其中image_adv_embed=Ei(xi′) image_embed=Ei(xi)
                loss_image_trades = criterion(image_adv_embed.log_softmax(dim=-1), image_embed.softmax(dim=-1).repeat(self.repeat, 1))
                #计算论文的第二条损失: KL_loss( E(xi′) , Et(xt′) ),其中text_adv_embed=Et(xt′)
                loss_adv_text = criterion(F.normalize(image_adv_embed, dim=-1).log_softmax(dim=-1),
                                          F.normalize(text_adv_embed, dim=-1).softmax(dim=-1).repeat(self.repeat, 1))
                #最终的对抗损失 Loss= KL_loss( E(xi′) , E(xi) ) +  α*KL_loss( E(xi′) , Et(xt′) )
                loss = loss_image_trades + alpha * loss_adv_text
                loss.backward()
            # 继续执行PGD attack,获取对抗损失对噪声的梯度符号并更新噪声值,然后中断PGD攻击返回一轮噪声添加后的样本
            images_adv = next(image_attack)

        else:
            images_adv = images

        if adv == 1 or adv == 3 or adv == 4 or adv == 5:
            with torch.no_grad():
                text_adv = self.text_attacker.attack(self.net, text, k)
        else:
            text_adv = text

        return images_adv, text_adv

4.3 PGD攻击代码

    def attack(self, image, num_iters):
        if self.random_init:
            self.delta = random_init(image, self.norm_type, self.epsilon)
        else:
            self.delta = torch.zeros_like(image)

        if hasattr(self, 'kernel'):
            self.kernel = self.kernel.to(image.device)

        if hasattr(self, 'grad'):
            self.grad = torch.zeros_like(image)


        epsilon_per_iter = self.epsilon / num_iters * 1.25

        for i in range(num_iters):
            self.delta = self.delta.detach()
            self.delta.requires_grad = True

            image_diversity = self.input_diversity(image + self.delta)
            #plt.imshow(image_diversity.cpu().detach().numpy()[0].transpose(1, 2, 0))
            #图片初始化等操作
            #images_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
            if self.preprocess is not None:
                image_diversity = self.preprocess(image_diversity)
            #暂停当初函数的执行临时返回image对抗样本,直到下一个调用next()再继续执行
            #python中yield的目的是节省内存消耗
            yield image_diversity

            grad = self.get_grad()
            grad = self.normalize(grad)
            self.delta = self.delta.data + epsilon_per_iter * grad
            # constraint 1: epsilon
            self.delta = self.project(self.delta, self.epsilon)
            # constraint 2: image range
            self.delta = torch.clamp(image + self.delta, *self.bounding) - image

        yield (image + self.delta).detach()

4.4 Bert Attack攻击

    def attack(self, net, texts, k=10, num_perturbation=1, threshold_pred_score=0.3, max_length=30, batch_size=32):
        device = self.ref_net.device

        text_inputs = self.tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)

        # substitutes
        mlm_logits = self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask).logits
        word_pred_scores_all, word_predictions = torch.topk(mlm_logits, k, -1)  # seq-len k

        # original state
        origin_output = net.inference_text(text_inputs)
        if self.cls:
            origin_embeds = origin_output['text_embed'][:, 0, :].detach()
        else:
            origin_embeds = origin_output['text_embed'].flatten(1).detach()

        criterion = torch.nn.KLDivLoss(reduction='none')
        final_adverse = []
        for i, text in enumerate(texts):
            # word importance eval
            important_scores = self.get_important_scores(text, net, origin_embeds[i], batch_size, max_length)

            list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)

            words, sub_words, keys = self._tokenize(text)
            final_words = copy.deepcopy(words)
            change = 0

            for top_index in list_of_index:
                if change >= num_perturbation:
                    break

                tgt_word = words[top_index[0]]
                if tgt_word in filter_words:
                    continue
                if keys[top_index[0]][0] > max_length - 2:
                    continue

                substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]]  # L, k
                word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]

                substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
                                             threshold_pred_score)


                replace_texts = [' '.join(final_words)]
                available_substitutes = [tgt_word]
                for substitute_ in substitutes:
                    substitute = substitute_

                    if substitute == tgt_word:
                        continue  # filter out original word
                    if '##' in substitute:
                        continue  # filter out sub-word

                    if substitute in filter_words:
                        continue
                    '''
                    # filter out atonyms
                    if substitute in w2i and tgt_word in w2i:
                        if cos_mat[w2i[substitute]][w2i[tgt_word]] < 0.4:
                            continue
                    '''
                    temp_replace = copy.deepcopy(final_words)
                    temp_replace[top_index[0]] = substitute
                    available_substitutes.append(substitute)
                    replace_texts.append(' '.join(temp_replace))
                replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
                replace_output = net.inference_text(replace_text_input)
                if self.cls:
                    replace_embeds = replace_output['text_embed'][:, 0, :]
                else:
                    replace_embeds = replace_output['text_embed'].flatten(1)

                loss = criterion(replace_embeds.log_softmax(dim=-1), origin_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
                loss = loss.sum(dim=-1)
                candidate_idx = loss.argmax()

                final_words[top_index[0]] = available_substitutes[candidate_idx]

                if available_substitutes[candidate_idx] != tgt_word:
                    change += 1

            final_adverse.append(' '.join(final_words))

        return final_adverse

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值