基于GAN的手写数字生成系统

摘要:本研究提出了一种基于生成对抗网络(GAN)的手写数字生成系统。系统利用MNIST数据集训练GAN模型,实现手写数字的高质量生成。GAN模型由生成器和判别器组成,通过两者的对抗训练,使生成器能够生成逼真的手写数字图像。本研究还使用PYQt5开发了一个可视化界面,用户可以直观地观察生成过程和结果。实验结果表明,该系统能够有效生成高质量的手写数字图像,并在一定程度上增强了手写数字数据的多样性。

关键词:MNIST、GAN、PyQt5

1 研究背景意义

在人工智能和机器学习领域,生成模型的研究近年来取得了显著进展。生成对抗网络(GAN)作为一种重要的生成模型,因其在图像生成和数据增强方面的优越性能,受到了广泛关注。手写数字生成是其中一个典型的应用场景,能够在不增加额外标注数据的情况下,为手写数字识别系统提供更多的训练样本,进而提高识别系统的性能和鲁棒性。

MNIST数据集作为手写数字识别领域的经典数据集,为研究生成对抗网络提供了良好的实验基础。通过GAN模型生成新的手写数字图像,不仅可以丰富现有的数据集,还可以在数据稀缺的情况下,生成高质量的训练样本,解决数据不足的问题。

本研究基于GAN模型和MNIST数据集,开发了一个手写数字生成系统,并使用PYQt5进行可视化。通过该系统,用户可以直观地观察生成过程,评估生成结果的质量。这不仅有助于进一步理解GAN的工作原理,还为手写数字识别系统的研究和应用提供了有力的支持。

2 技术介绍

生成对抗网络 (Generative Adversarial Network, GAN) 是由 Goodfellow 1 于 2014 年提出的一种对抗网络。这个网络框架包含两个部分,一个生成模型 (generative model) 和一个判别模型 (discriminative model)。其中,生成模型可以理解为一个伪造者,试图通过构造假的数据骗过判别模型的甄别;判别模型可以理解为一个警察,尽可能甄别数据是来自于真实样本还是伪造者构造的假数据。两个模型都通过不断的学习提高自己的能力,即生成模型希望生成更真的假数据骗过判别模型,而判别模型希望能学习如何更准确的识别生成模型的假数据。

GAN 由两部分构成,一个生成器 (Generator) 和一个判别器 (Discriminator)。对于生成器,我们需要学习关于数据 𝑥 的一个分布 𝑝𝑔,首先定义一个输入数据的先验分布 𝑝𝑧(𝑧),其次定义一个映射 𝐺(𝑧;𝜃𝑔):𝑧→𝑥。对于判别器,我们则需要定义一个映射 𝐷(𝑥;𝜃𝑑) 用于表示数据 𝑥 是来自于真实数据,还是来自于 𝑝𝑔。GAN 的网络框架如下图所示 :

3 功能演示

3.1实时训练模型

在设置好训练轮数(Epoch)后,模型开始进行实时训练。训练过程中,界面上会显示实时的训练结果,同时控制台会输出损失值、准确率等参数。此外,训练模型会被实时保存。

3.2 加载模型进行结果预测

点击“Load Model”按钮,从文件夹中选择训练好的生成器模型,然后开始绘制生成的手写数字。生成的图像将被保存到SaveImg文件夹中,并在界面上显示。

4 核心代码讲解

4.1 判别器训练

def train_D(self,inputs,targets):
        outputs = self.D.forward(inputs)
        loss = self.loss_F(outputs, targets)
        self.out_D.append(outputs)
        self.tar_D.append(targets)
        self.counter+=1
        if (self.counter % 2000 == 0):  # % 10表示除以10之后的余数,当计数器为10、20、30等时,余数为0。

            self.out_D=[x.item() for x in self.out_D]
            self.out_D=list(np.where(np.array(self.out_D)>0.5,1,0).astype(int))
            self.tar_D=[x.item() for x in self.tar_D]
            self.tar_D = list(np.array(self.tar_D).astype(int))
            acc=metrics.accuracy_score(y_true=self.tar_D,y_pred=self.out_D)
            pre = metrics.precision_score(y_true=self.tar_D,y_pred=self.out_D)
            rec = metrics.recall_score(y_true=self.tar_D,y_pred=self.out_D)
            f1 = metrics.f1_score(y_true=self.tar_D,y_pred=self.out_D)
            mcc = metrics.matthews_corrcoef(y_true=self.tar_D,y_pred=self.out_D)
            print(acc)
            self.out_D = []
            self.tar_D = []
            self.loss_info.emit(f"train_D:当前counter:{self.counter},当前loss={loss.item()},Acc={acc},pre={pre},rec={rec},f1={f1},mcc={mcc}")
        self.D.optimiser.zero_grad()
        loss.backward()
        self.D.optimiser.step()

实现了生成对抗网络(GAN)中判别器(Discriminator)的训练过程。在每次训练中,判别器对输入数据进行前向传播并计算损失值,同时将输出和目标值存储起来。每训练2000次,代码会将存储的输出值和目标值进行二值化,并计算准确率、精确率、召回率、F1分数和Matthews相关系数等评估指标,打印并显示这些指标信息。随后,代码将梯度归零,进行反向传播并更新模型参数。

4.2 生成器训练

def train_G(self,inputs,targets):
        g_output = self.G.forward(inputs)
        # 输入鉴别器
        d_output = self.D.forward(g_output)
        # 计算损失值
        loss = self.D.loss_function(d_output, targets)
        # 每训练20次增加计数器
        self.counter += 1
        if (self.counter % 2000 == 0):
            self.loss_info.emit(f"train_G:当前counter:{self.counter},当前loss:{loss.item()}")
        if self.counter%1000==0:
            torch.save(self.G, 'current_G.pt')
            self.loss_info.emit("Draw")
        self.G.optimiser.zero_grad()
        loss.backward()
        self.G.optimiser.step()

用于训练生成对抗网络(GAN)的生成器。具体来说,它通过生成器 `G` 生成图像,并将这些图像输入到鉴别器 `D` 中计算损失值。每训练2000次会发出当前计数器和损失值的信号,每训练1000次会保存当前生成器模型。接着,代码会清空生成器的梯度,反向传播损失,并更新生成器的参数。

4.3 组合训练

def run(self) -> None:
        self.D.train()
        self.G.train()
        for epoch in range(self.epochs):
            for step, (images, labels) in enumerate(self.train_loader):
                image_data_tensor = images.view(-1)
                self.train_D(image_data_tensor,torch.FloatTensor([1.0]))
                self.train_D(self.G.forward(generate_random(100)).detach(), torch.FloatTensor([0.0]))
                self.train_G(generate_random(100), torch.FloatTensor([1.0]))
        torch.save(self.G, 'GAN_Digits_G.pt')

用于训练生成对抗网络(GAN)。具体功能如下:

1. 将鉴别器 `D` 和生成器 `G` 设置为训练模式。

2. 在指定的训练周期数 `epochs` 内,遍历训练数据 `train_loader`。

3. 对于每一个批次的图像 `images` 和标签 `labels`:

   - 将图像数据展平为一维张量。

   - 使用真实图像数据训练鉴别器,使其输出接近1(表示真实)。

   - 使用生成器生成随机图像,并使用这些假图像数据训练鉴别器,使其输出接近0(表示虚假)。

   - 使用随机噪声训练生成器,使其输出被鉴别器判别为接近1(表示生成的图像尽可能真实)。

4. 在训练结束后,保存训练好的生成器模型 `G` 到文件 `GAN_Digits_G.pt`。

关注GZH:阿欣Python与机器学习,发送【源码】即可获取下载方式

  • 12
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值