文章目录
1. 对抗样本生成原理
使用快速梯度符号算法(FGSM: Fast Gradient Sign Method)生成对抗样本
用原始图像初始化对抗样本,通过损失函数计算梯度,根据FGM算法迭代更新对抗样本,直到满足最大迭代次数,或者对抗样本预测值达到预期为止。
梯度上升的方式,
2. 实验与结果
"""
#-*- coding = utf-8 -*-
#@Time: 2022 12 27 上午10:08
#@Author:JFZ
#@File:02_基于梯度的对抗样本.py
#@Software: PyCharm
"""
import cv2
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
# 图像读取
image_path = "/home/harry/LOCAL/Python_Demo/AI对抗样本入门/AI对抗学习/picture/cropped_panda.jpg"
image = cv2.imread(image_path)
image = cv2.resize(image, (224, 224))
image_backup = image.copy()
image = image / 255.0
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
image = ((image - mean) / std).astype("float32")
tensor_image = torchvision.transforms.ToTensor()(image)[None, ...]
# 模型加载
model = models.alexnet(pretrained=True).eval()
for param in model.parameters():
param.requires_grad = False
src_pred_index = model(tensor_image).argmax(dim=1) # tensor[388]
# 基于梯度的对抗样本
adv_tensor_image = tensor_image.clone()
adv_tensor_image.requires_grad = True
optimizer = torch.optim.Adam(params=[adv_tensor_image])
epochs = 100
target = 150
target = torch.tensor([target])
loss_func = nn.CrossEntropyLoss()
for epoch in range(epochs):
output = model(adv_tensor_image)
loss = loss_func(output, target)
lable_index = torch.argmax(output, dim=1).item()
print("epoch {} | loss = {:.4f} | lable = {}".format(epoch + 1, loss, lable_index))
if lable_index == target.item():
break
optimizer.zero_grad() # 手动清零,触发反向传播
loss.backward()
# 这里就可以计算出adv_tensor_image的梯度值,然后根据FGSM来更新对抗样本了
adv_tensor_image.data = adv_tensor_image.data - 0.01 * torch.sign(adv_tensor_image.grad.data) # 基于FGSM,
# 可视化
adv_image = adv_tensor_image.detach()[0].permute([1, 2, 0]).numpy()
adv_image = np.clip((adv_image * std + mean) * 255, 0, 255).astype("uint8")
plt.subplot(131)
plt.imshow(image_backup)
plt.title(src_pred_index)
plt.subplot(132)
plt.imshow(adv_image)
plt.title(lable_index)
plt.subplot(133)
plt.imshow(adv_image - image_backup)
plt.title("difference")
plt.show()
epoch 1 | loss = 10.8752 | lable = 388
epoch 2 | loss = 6.3648 | lable = 194
epoch 3 | loss = 4.2419 | lable = 360
epoch 4 | loss = 3.0041 | lable = 684
epoch 5 | loss = 0.8586 | lable = 150