FGM实现

import tensorflow as tf

class FGM():
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, gradients, epsilon=1.0, emb_name=''):
        # 遍历模型中的所有可训练参数,并施加对抗扰动
        for var, grad in zip(self.model.trainable_variables, gradients):
            if emb_name in var.name:  # 仅对Embedding层进行对抗攻击
                self.backup[var.name] = tf.identity(var)  # 备份原始参数
                norm = tf.norm(grad)
                if norm != 0 and not tf.math.is_nan(norm):
                    r_at = epsilon * grad / norm  # 计算扰动
                    var.assign_add(r_at)  # 对参数施加扰动

    def restore(self, emb_name='embed_net'):
        # 恢复原始Embedding参数
        for var in self.model.trainable_variables:
            if emb_name in var.name and var.name in self.backup:
                var.assign(self.backup[var.name])  # 恢复备份的参数
        self.backup = {}  # 清空备份

@tf.function
def train_step(model, optimizer, loss_fn, x_batch, y_batch, fgm, epsilon=1.0):
    # 正常的前向传播和梯度计算
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)
        loss = loss_fn(y_batch, predictions)

    # 计算普通的梯度
    gradients = tape.gradient(loss, model.trainable_variables)

    # 应用正常的梯度更新
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # 执行FGM对抗攻击
    fgm.attack(gradients, epsilon=epsilon)  # 传递梯度用于计算对抗扰动
    
    # 重新计算对抗扰动后的梯度
    with tf.GradientTape() as tape_adv:
        predictions_adv = model(x_batch, training=True)
        loss_adv = loss_fn(y_batch, predictions_adv)

    gradients_adv = tape_adv.gradient(loss_adv, model.trainable_variables)
    
    # 应用对抗扰动后的梯度更新
    optimizer.apply_gradients(zip(gradients_adv, model.trainable_variables))

    # 恢复Embedding层参数
    fgm.restore()

    return loss, loss_adv

FGM类:负责保存模型中的Embedding层的参数备份,执行对抗攻击,并恢复参数。

attack: 在给定的epsilon下,计算对Embedding层的梯度扰动并加到原始参数上。
restore: 恢复之前备份的Embedding层参数,防止对抗扰动持续影响模型。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值