advGAN代码笔记

这是一种使用GAN来生成对抗样本的模型

代码:

首先来看一个训练过程

        代码中首先训练的是D

        首先用generator生成干扰项 perturbation,然后与原图相加形成对抗样本 adv_images

        当然训练一个D的loss分为了两部分,loss_D_real旨在拉近吃正样本之后的输出与1的距离

loss_D_fake旨在拉近吃负样本之后与0的距离,这里的负样本就是对抗样本,输入的时候不要忘了detach掉

        # optimize D
        # x are the input images
        for i in range(1):
            perturbation = self.netG(x)  # torch.Size([128, 1, 28, 28])

            # add a clipping trick
            adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
            adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

            self.optimizer_D.zero_grad()
            pred_real = self.netDisc(x)
            loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
            loss_D_real.backward()

            pred_fake = self.netDisc(adv_images.detach())
            loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))
            loss_D_fake.backward()
            loss_D_GAN = loss_D_fake + loss_D_real
            self.optimizer_D.step()

        训练G的过程就有些复杂了

        首先要G的Gan损失的训练目标是让自己生成的对抗样本,在D看起来和正样本1相近

下方的retain_graph = True的意思是保留当前方向传播的计算图,可以做梯度累加

可以参见这两篇博客https://www.cnblogs.com/picassooo/p/13748618.html

https://www.cnblogs.com/picassooo/p/13818952.html

            self.optimizer_G.zero_grad()

            # cal G's loss in GAN
            pred_fake = self.netDisc(adv_images)
            loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))
            loss_G_fake.backward(retain_graph=True)

        接下来就是限制扰动大小的损失

        这里设计的是一个batch之中所有图片的矩阵二范数都不能太大

            # calculate perturbation norm
            C = 0.1
            loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

        接下来就是样本对抗损失

        onehot_labels这里的实现是比较优雅的,总体功能是根据手写数字的类别转换为onehot编码的格式,torch.eye的功能就是得到onehot编码,然后使用lables变量中对应的类别把他提取出来

        real的功能,按照我粗浅的理解,是得到网络针对一个batch中所有对抗样本预测正确的概率。other的功能,是得到了网络针对一个batch中的所有对抗样本预测为错误的类别中,可能性最大的概率。

        那个torch.max(real-other,0)的功能,按照我粗浅的理解,首先看real-ohter的部分,因为损失函数都是梯度下降的,最小化这个损失函数,相当于训练模型让real更小,other更大,犯错的概率越大。之所以要与0相max,也许是小于0的时候,other已经大于real了,然后没必要训练这个部分了?

        最后一个损失函数我可能理解的不正确,还是要看一下那个C&W模型是怎么设计的

            # cal adv loss
            logits_model = self.model(adv_images)
            probs_model = F.softmax(logits_model, dim=1)
            onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]  # torch.Size([128, 10])

            # C&W loss function
            real = torch.sum(onehot_labels * probs_model, dim=1)  # [128]
            other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
            zeros = torch.zeros_like(other)
            loss_adv = torch.max(real -
                                 other, zeros)
            loss_adv = torch.sum(loss_adv)

接下来就是把这两个loss乘以一个超参权重,然后backward就好了

            adv_lambda = 10
            pert_lambda = 1
            loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
            loss_G.backward()
            self.optimizer_G.step()

        return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: Python 代码笔记是对 Python 程序代码的解释和说明。它可以帮助你理解代码的工作原理,并在以后更好地维护和编写代码。常用的代码笔记格式有注释、文档字符串等。示例代码: ```python # 计算平方 def square(x): """ 返回x的平方 """ return x*x print(square(4)) ``` 在上面的代码中,`# 计算平方`是注释,`"""返回x的平方"""`是文档字符串。 ### 回答2: Python代码笔记是程序员在学习和实践Python编程语言时记录的一种文档。它包括通过编写实际的Python代码示例来记录各种语法、函数、模块、库和算法的用法和应用。 Python代码笔记通常用于记录和整理编程语言的基本知识,并用代码示例来演示这些知识的具体使用。因为Python语言本身较为简洁易读,因此在代码笔记中使用Python语言编写示例代码非常方便。 通过编写Python代码笔记,程序员可以更好地理解和掌握Python编程语言的特性和用法。而且代码笔记还可以作为程序员的参考资料,帮助他们在遇到问题时快速找到解决方案并进行复用。 除了记录基本知识之外,Python代码笔记还可以用于记录程序员在实际项目中遇到的问题和解决方案。通过记录这些问题和解决方案,程序员可以在未来的项目中预防和避免相同的问题,并且能够提高自己的编程技巧和经验。 总之,Python代码笔记是程序员学习和实践Python编程语言时记录的一种文档。它可以帮助程序员整理知识、提高编程技巧,并成为他们解决问题和提高效率的有力工具。 ### 回答3: Python代码笔记是程序员在学习和使用Python语言时记录的一种方式。它可以包括以下内容: 首先,Python代码笔记通常会记录Python代码的基本语法和用法。这些笔记会列举Python的关键字、变量类型、运算符、控制流语句等基本知识点,以便在需要的时候进行快速查阅和复习。 其次,Python代码笔记还会记录一些常用的Python库和模块的使用方法。Python具有丰富的第三方库和模块,如numpy、pandas、matplotlib等,这些库在数据处理、科学计算、绘图等领域都有广泛的应用。通过记录库和模块的使用方法,可以帮助程序员实现特定的功能或解决具体的问题。 此外,Python代码笔记还会记录一些常见的编程技巧和经验。比如如何提高代码的效率、如何优化算法、如何进行调试等等。这些技巧和经验是程序员在实际开发中积累的宝贵资料,可以帮助他们更好地解决问题和提高工作效率。 最后,Python代码笔记还可以记录一些项目示例和实践经验。当程序员在开发具体的项目时,他们会遇到各种问题和挑战,记录下来的项目示例和实践经验可以为他们以后的开发工作提供参考和借鉴。这些实践经验可以包括项目的架构设计、数据库操作、接口调用等方面的知识。 综上所述,Python代码笔记是程序员学习和使用Python语言的重要辅助工具,它通过记录基本语法、常用库和模块的使用、编程技巧和经验以及项目示例和实践经验等内容,帮助程序员提高开发效率,解决问题,并不断提升自己的编程能力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值