PGD Attack: Generating Adversarial Examples

import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from advertorch.attacks import PGDAttack
from PIL import Image
import matplotlib.pyplot as plt
import requests

# Load the pretrained ResNet model and the test image.
# ResNet-50 expects ImageNet-normalized inputs; wrapping the normalization inside the
# model lets PGD perturb raw [0, 1] pixels while the predictions stay correct.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = torch.nn.Sequential(normalize, resnet50(pretrained=True))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
image_path = 'C:/Users/Administrator/Desktop/flower.jpg'
image = transform(Image.open(image_path)).unsqueeze(0)

# Define the loss function and the PGD attacker.
# eps bounds the L-infinity perturbation; eps_iter is the per-iteration step size.
criterion = torch.nn.CrossEntropyLoss()
adversary = PGDAttack(model, loss_fn=criterion, eps=0.01, nb_iter=40, eps_iter=0.01)

# Run the PGD attack to generate the adversarial example.
# 'daisy' is class 985 in ImageNet; for a different image, run inference once to get its label.
label = torch.tensor([985])
adv_image = adversary.perturb(image, label)

# Visualize the original image and the adversarial example side by side.
original_image = transforms.ToPILImage()(image.squeeze(0))
adversarial_image = transforms.ToPILImage()(adv_image.squeeze(0))

plt.subplot(1, 2, 1)
plt.imshow(original_image)
plt.title('Original Image')
plt.subplot(1, 2, 2)
plt.imshow(adversarial_image)
plt.title('Adversarial Image')
plt.show()

# Classify the original image and the adversarial example.
with torch.no_grad():
    output_original = model(image)
    output_adversarial = model(adv_image)
_, predicted_original = torch.max(output_original, 1)
_, predicted_adversarial = torch.max(output_adversarial, 1)

# Fetch the human-readable ImageNet class labels.
labels = requests.get("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json").json()
print(f"Original Prediction: {labels[predicted_original.item()]}")
print(f"Adversarial Prediction: {labels[predicted_adversarial.item()]}")
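For intuition about what PGDAttack does internally: each iteration takes a step of size eps_iter along the sign of the input gradient of the loss, then projects the result back into the L-infinity ball of radius eps around the clean image and into the valid pixel range [0, 1]. The snippet below is a minimal hand-rolled sketch of that loop, not advertorch's actual implementation; the helper name pgd_linf and its defaults are illustrative only.

import torch

def pgd_linf(model, x, y, eps=0.01, alpha=0.01, steps=40):
    # Minimal L-infinity PGD sketch: ascend the loss along the gradient sign,
    # then project back into the eps-ball around the clean input.
    x_orig = x.detach()
    # Random start inside the eps-ball, similar to advertorch's default rand_init=True.
    x_adv = (x_orig + torch.empty_like(x_orig).uniform_(-eps, eps)).clamp(0, 1)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = loss_fn(model(x_adv), y)
        grad, = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()                              # gradient-sign step
            x_adv = torch.min(torch.max(x_adv, x_orig - eps), x_orig + eps)  # project to eps-ball
            x_adv = x_adv.clamp(0, 1)                                        # keep valid pixels
    return x_adv.detach()

# pgd_linf(model, image, label) should behave much like adversary.perturb(image, label) above.

Note that using eps_iter equal to eps works here only because every step is re-projected onto the eps-ball; a step size several times smaller than eps is the more common choice.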