FGSM对抗训练(MNIST数据集)- pytorch实现

文章介绍了使用FastGradientSignMethod(FGSM)在LeNet网络上进行MNIST手写数字识别的对抗训练过程,包括生成对抗样本、测试模型对不同epsilon值的鲁棒性,以及对抗训练后的模型性能提升。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 概要

使用FGSM实现手写体识别的对抗训练与分析

2 整体架构流程

  1. 搭建LeNet网络训练MNIST的分类模型,并测试准确率。
  2. 生成不同epsilon值的对抗样本,送入训练好的模型,再次测试准确率,得出结果。
  3. 基于训练集生成对抗数据集,与原训练集一同进行训练,再次测试不同epsilon值下模型的准确度。

2.1 搭建LeNet网络进行训练

2.1.1 导入库

import os.path
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import warnings
from matplotlib import MatplotlibDeprecationWarning

2.1.2 matplotlib异常处理

# Silence matplotlib deprecation warnings so plots render without console noise.
warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)

# Work around the duplicate-OpenMP-runtime crash (libiomp5) that can occur on
# some Windows/conda setups when torch and matplotlib are loaded together.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

2.1.3 载入torchvision中的MNIST数据集

# FIX: the original line started with a stray "**" (markdown-bold residue),
# which is a syntax error in Python.
# Download (if needed) and load the MNIST train/test splits as tensors in [0, 1].
train_data = torchvision.datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
test_data = torchvision.datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

batch_size = 100

# NOTE: shuffle is left at its default (False), so batch order is deterministic.
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

2.1.4 MNIST图像示例

# Visualize n * batch_size training images in a grid.
plt.figure(figsize=(10, 10))
n = 1

# FIX: the original called next(iter_dataloader) without ever defining
# iter_dataloader, raising NameError. Create the iterator once, outside the
# loop, so successive next() calls yield successive batches.
iter_dataloader = iter(train_dataloader)

# 取出n*batch_size张图片可视化 (show n * batch_size images)
for i in range(n):
    images, labels = next(iter_dataloader)

    image_grid = torchvision.utils.make_grid(images, nrow=10)
    plt.subplot(1, n, i + 1)

    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(image_grid.numpy(), (1, 2, 0)))
plt.show()

在这里插入图片描述

2.1.5 训练设备选择(GPU/CPU)

# Pick the compute device: prefer CUDA when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

2.1.6 LeNet网络模型搭建

class LeNet(nn.Module):
    """LeNet-style CNN for 28x28 single-channel MNIST digits (10 classes).

    The module layout (self.conv / self.flatten / self.fc as Sequentials, in
    this order) is kept identical to the original definition so previously
    saved state_dicts (model.pth, model_FGSM.pth) still load.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        # Feature extractor: 1x28x28 -> 6x14x14 -> 16x6x6 (= 576 features).
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.flatten = nn.Flatten()
        # Classifier head: 576 -> 120 -> 84 -> 10 raw logits.
        self.fc = nn.Sequential(
            nn.Linear(576, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input (batch, 1, 28, 28)."""
        features = self.conv(x)
        flat = self.flatten(features)
        return self.fc(flat)

2.1.7 模型训练函数

def train(network, optimizer, loss_fn):
    """Train `network` for 20 epochs on the module-level train_dataloader.

    Saves the final-epoch weights to "model.pth" and plots the per-epoch
    mean loss curve.
    """
    epochs = 20
    epoch_losses = []

    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            logits = network(inputs)
            batch_loss = loss_fn(logits, targets)
            running_loss += batch_loss.item()

            # Standard SGD step: clear stale grads, backprop, update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        epoch_mean = running_loss / len(train_dataloader)
        epoch_losses.append(epoch_mean)
        print(f"Epoch {epoch+1} loss: {epoch_mean:>7f}")

    # 训练完毕保存最后一轮训练的模型 (save the last epoch's weights)
    torch.save(network.state_dict(), "model.pth")

    # 绘制损失函数曲线 (plot the training-loss curve)
    plt.xlabel("Epochs")
    plt.ylabel("Loss Value")
    plt.plot(list(range(len(epoch_losses))), epoch_losses)
    plt.show()
network = LeNet()
network.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=network.parameters(), lr=0.001, momentum=0.9)

# Reuse cached weights when present; otherwise train from scratch.
# FIX: map_location=device lets a checkpoint saved on a GPU machine load on a
# CPU-only machine (torch.load without it raises on missing CUDA).
if os.path.exists('model.pth'):
    network.load_state_dict(torch.load('model.pth', map_location=device))
else:
    train(network, optimizer, loss_fn)

在这里插入图片描述
在这里插入图片描述

2.1.8 测试模型准确度

# Evaluate clean (unperturbed) test-set accuracy of the trained model.
positive = 0
negative = 0
for X, y in test_dataloader:
    with torch.no_grad():  # inference only - no gradient tracking needed
        X, y = X.to(device), y.to(device)
        logits = network(X)
        # Tally per-batch with tensor ops instead of a per-sample Python loop.
        batch_correct = (logits.argmax(dim=1) == y).sum().item()
        positive += batch_correct
        negative += y.size(0) - batch_correct
acc = positive / (positive + negative)
print(f"Accuracy: {acc * 100}%")

在这里插入图片描述

2.2 FGSM对抗样本测试

2.2.1 使用FGSM生成对抗样本

# Epsilon values explored throughout the remaining experiments.
eps = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Take a single test batch and show the FGSM sign-gradient (left panel)
# next to the resulting adversarial images for eps[2] = 0.1 (right panel).
X, y = next(iter(test_dataloader))
X, y = X.to(device), y.to(device)

# Track gradients w.r.t. the input so FGSM can read d(loss)/d(X).
X.requires_grad = True
pred = network(X)
network.zero_grad()
loss = loss_fn(pred, y)
loss.backward()

plt.figure(figsize=(15, 8))

# Left: sign of the input gradient, clamped into [0, 1] for display.
plt.subplot(121)
image_grid = torchvision.utils.make_grid(torch.clamp(X.grad.sign(), 0, 1), nrow=10)
plt.imshow(np.transpose(image_grid.cpu().numpy(), (1, 2, 0)))

# Right: adversarial examples x + eps * sign(grad), clamped to valid pixels.
X_adv = torch.clamp(X + eps[2] * X.grad.sign(), 0, 1)

plt.subplot(122)
image_grid = torchvision.utils.make_grid(X_adv, nrow=10)
plt.imshow(np.transpose(image_grid.cpu().numpy(), (1, 2, 0)))

plt.show()

在这里插入图片描述

2.2.2 不同epsilon值下对模型分类准确度的影响探究

# Measure model accuracy on FGSM adversarial examples for each epsilon.
acc_list = []
for epsilon in eps:
    # BUG FIX: reset the counters for every epsilon. The original initialized
    # positive/negative once before the loop, so each printed "accuracy" was a
    # running average over ALL epsilons tested so far, not the accuracy at the
    # current epsilon.
    positive = 0
    negative = 0

    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)

        # Enable input gradients so FGSM can use d(loss)/d(input).
        X.requires_grad = True
        pred = network(X)

        network.zero_grad()
        loss = loss_fn(pred, y)
        loss.backward()

        # FGSM step, then clamp back into the valid pixel range [0, 1].
        X_adv = torch.clamp(X + epsilon * X.grad.sign(), 0, 1)

        # Re-classify the perturbed batch and tally correct predictions.
        pred = network(X_adv)
        for item in zip(pred, y):
            if torch.argmax(item[0]) == item[1]:
                positive += 1
            else:
                negative += 1

    acc = positive / (positive + negative)
    print(f"epsilon={epsilon} acc: {acc * 100}%")
    acc_list.append(acc)

plt.xlabel("epsilon")
plt.ylabel("Accuracy")
plt.ylim(0, 1)
plt.plot(eps, acc_list, marker='o')
plt.show()

在这里插入图片描述

2.3 FGSM对抗训练

2.3.1 生成FGSM对抗数据集,并与原数据集合并

生成对抗数据集

def fgsm_attack(network, X, y, epsilon):
    """Build FGSM adversarial examples for the batch (X, y).

    The gradient is taken w.r.t. a zero perturbation `delta` added to X,
    which equals the gradient w.r.t. X itself. Uses the module-level
    `loss_fn`. Returns the perturbed batch clamped to [0, 1].
    """
    delta = torch.zeros_like(X, requires_grad=True)

    # Forward + backward through the zero perturbation to populate delta.grad.
    loss_fn(network(X + delta), y).backward()

    # One signed gradient step of size epsilon, clamped to valid pixel range.
    step = epsilon * delta.grad.detach().sign()
    return torch.clamp(X + step, 0, 1)

对抗数据集与原数据集合并并训练

def train(network, optimizer, loss_fn):
    """Adversarial training: every batch is doubled with FGSM examples (eps=0.5).

    Saves the final-epoch weights to "model_FGSM.pth" and plots the
    per-epoch mean loss curve.
    """
    epochs = 20
    epoch_losses = []

    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # 生成对抗样本 (craft adversarial counterparts of the clean batch)
            adversarial = fgsm_attack(network, inputs, targets, 0.5)

            # 合并样本 (train on clean + adversarial together; labels repeat)
            batch_inputs = torch.cat((inputs, adversarial), dim=0)
            batch_targets = torch.cat((targets, targets), dim=0)

            optimizer.zero_grad()
            loss = loss_fn(network(batch_inputs), batch_targets)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()

        epoch_mean = running_loss / len(train_dataloader)
        epoch_losses.append(epoch_mean)
        print(f"Epoch {epoch+1} loss: {epoch_mean:>7f}")

    # 训练完毕保存最后一轮训练的模型 (save the last epoch's weights)
    torch.save(network.state_dict(), "model_FGSM.pth")

    # 绘制损失函数曲线 (plot the training-loss curve)
    plt.xlabel("Epochs")
    plt.ylabel("Loss Value")
    plt.plot(list(range(len(epoch_losses))), epoch_losses)
    plt.show()

在这里插入图片描述

测试不同epsilon下模型的准确度

# Load the adversarially-trained weights for evaluation.
# FIX: the train() above saves to "model_FGSM.pth", but this line loaded
# "model_FGSM_0.5.pth" - a file this script never writes. Align the names,
# and pass map_location so a GPU-saved checkpoint also loads on CPU.
network.load_state_dict(torch.load('model_FGSM.pth', map_location=device))

在这里插入图片描述

3 小结

  • 对同一个分类模型来说,随着epsilon的增加,fgsm生成的对抗样本使得分类准确度减小
  • 在训练集添加fgsm生成的对抗样本后,训练的模型对不同epsilon的对抗样本分类准确度有显著提升。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值