GAN详解与PyTorch MINIST手写数字生成实战

GAN简介

GAN(Generative Adversarial Nets) 用中文来说就是 生成对抗网络,它是Ian J. Goodfellow在2014年提出的一种深度学习网络模型。它包含两个模型:生成模型和辨别模型。生成模型是用来捕捉真实数据分布来生成符合原始数据分布的新的数据,辨别模型是用来辨别真实数据和生成模型生成的数据。生成模型的目的是为了来让辨别模型犯错,辨别模型的目的是为了区分生成数据和真实数据。这就好像是两个模型在互相对抗,在对抗中不断吸取经验从而来让自身得到提升,可以类比博弈论中的两人对抗游戏。GAN 可以通过**MLP(多层感知机)**来进行误差反传进行训练。这样就比使用马尔科夫链或者对近似推理过程展开更加简单。由于GAN在生成图像方面有很好的效果,所以得到了很广泛的应用,比如生成名人小时候的照片,将真实人物变成卡通形式,甚至还可以生成世界上不存在的人的脸部照片。

GAN论文原理

    对于生成器,以生成图片为例,我们需要输入一个噪声 z z z,就类似一个一百维的变量吧,然后 z z z通过从真实数据 x x x中学习到的分布 p g p_g pg去进行映射,就可以生成一张图片 G ( z ) G(z) G(z)
    对于辨别器D,我们就是去对生成数据和真实数据进行分类,类似一个两类的分类器。设 D ( x ) D(x) D(x)表示 x x x是来自真实数据而不是 p g p_g pg的概率。
    根据GAN的需求,我们需要尽可能让辨别器能够辨别真实数据和生成数据并且让生成器生成数据让辨别器尽可能犯错。简而言之就是最大化 l o g ( D ( x ) ) log(D(x)) log(D(x)),最小化 l o g ( 1 − D ( G ( x ) ) ) log(1-D(G(x))) log(1D(G(x))),所以我们可以得到如下公式:
在这里插入图片描述
这样 D D D G G G就在好像进行两人对抗游戏。
在这里插入图片描述
    这是GAN的训练过程,其中绿色的线为生成器生成的数据,黑色的点为真实数据,蓝色的点为辨别器的结果。从a-b-c-d可以看出,生成器生成的数据在不断向真实数据拟合,辨别的结果也在不断改变,最后黑色的点和绿色的线完全拟合时辨别器无法辨别真实数据和生成数据时,此时辨别器的曲线值为0.5(0表示生成数据,1表示真实数据),就无法通过辨别器的值来辨别数据来源。

下面介绍GAN的算法:
在这里插入图片描述
    这里比较重要的就是 k k k的取值,它会关系到我们模型训练的好坏。 k k k的取值不能太小,也不能太大。如果 k k k的取值太小,这样每次更新生成器后辨别器得不到充分的更新,无法很好辨别真实数据和生成数据,这时就算不更新生成器也能糊弄辨别器,此时更新生成器的意义不大;如果 k k k的取值太大,意味着生成器更新后辨别器会被更新得很好,此时上述生成器梯度公式中 l o g ( 1 − D ( G ( z ( i ) ) ) ) log(1-D(G(z^{(i)}))) log(1D(G(z(i))))就是0,这是就对0求梯度,这样在生成模型的更新上会有困难。这里我们类比一个例子更好理解:假设辨别器就是警察,生成器就是造假者。如果警察太厉害,那么造假者生产一点假钞就被一锅端了,那么造假者就没法赚到钱,不能去进一步改进工艺;如果警察太无力,无法比较好分辨真钞和假钞,那么造假者随便生产点东西都能赚到钱,这样生产者就不会想着去改进工艺。所以两方面都不行,最好的就是两方实力相当,这样都能互相促进进步。

MINIST手写数字生成实战

这是一个利用手写数据集进行训练得到的GAN,生成器接收随机噪声作为输入,然后输出一张手写数字图像;判别器的输入则是两幅图像,分别是真的手写数字图像和生成器生成的假图像,然后输出对这两幅图像的判别结果。
在这里插入图片描述

1、导入MINIST数据集。

train_data = dataloader.DataLoader(datasets.MNIST(root='data/', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
]), download=True), shuffle=True, batch_size=batch_sz)

2、构建辨别器和生成器

辨别器:

class discrimination(nn.Module):
    def __init__(self):
        super(discrimination, self).__init__()
        self.hidden0 = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
        )

        self.out = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

这里我们对辨别器网络构造采用4层,前三层用LeakyReLu,最后一层用sigmoid。使用LeakyReLu是因为不会将零以下的数全部置为零,所以使用LeakyReLU 激活函数相比使用ReLU 能够更好地使梯度流过网络。使用sigmoid是因为能够将输出值约束在区间[0, 1]


生成器:

class generate(nn.Module):
    def __init__(self):
        super(generate, self).__init__()
        self.hidden0 = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )

        self.out = nn.Sequential(
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

生成器前面三层和辨别器一样,最后一层采用tanh激活函数,是为了与对MNIST 数据进行的归一化同步,以将其值转换到[-1, 1] 中,以便判别器始终获取数据点处于相同值域的数据集。

3、训练模型

def train_discriminator(optimizer, loss_fn, real_data, fake_data):
    optimizer.zero_grad()

    discriminator_real_data = discriminator(real_data)
    loss_real = loss_fn(discriminator_real_data, torch.ones(real_data.size(0), 1).to(device))
    loss_real.backward()

    discriminator_fake_data = discriminator(fake_data)
    loss_fake = loss_fn(discriminator_fake_data, torch.zeros(fake_data.size(0), 1).to(device))
    loss_fake.backward()

    optimizer.step()

    return loss_real + loss_fake, discriminator_real_data, discriminator_fake_data

def train_generator(optimizer, loss_fn, fake_data):
    optimizer.zero_grad()

    output_discriminator = discriminator(fake_data)
    loss = loss_fn(output_discriminator, torch.ones(output_discriminator.size(0), 1).to(device))
    loss.backward()
    optimizer.step()
    return loss

for epoch in range(num_epoch):

    for train_idx, (input_real_batch, _) in enumerate(train_data):
        real_data = images2vectors(input_real_batch).to(device)
        generated_fake_data = generator(noise(real_data.size(0))).detach()
        d_loss, discriminated_real, discriminated_fake = train_discriminator(d_optimizer, loss_fn, real_data,
                                                                             generated_fake_data)

        generated_fake_data = generator(noise(real_data.size(0)))
        g_loss = train_generator(g_optimizer, loss_fn, generated_fake_data)

        if train_idx == len(train_data) - 1:
            print(epoch, 'd_loss: ', d_loss.item(), 'g_loss: ', g_loss.item())

train_discriminator和train_generator是为了分别对辨别器和生成器求loss,并进行反向传播、参数优化。辨别器涉及到真实数据和生成数据俩方面的误差(上面图中有提到),所以将他们相加起来。
源代码见 MINIST手写数字生成


小伙伴喜欢文章的话记得 点赞加关注 哦,后面会更新其他深度学习的文章。
如果有什么写得有问题的地方希望大家能值出,谢谢。

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: minist手写数字识别pytorch是一种基于PyTorch框架的手写数字识别模型。该模型可以识别到9的手写数字,并且在训练集和测试集上都有很好的表现。它的实现过程包括数据预处理、模型构建、模型训练和模型测试等步骤。通过使用PyTorch框架,可以方便地实现深度学习模型,并且可以利用GPU加速训练过程,提高模型的训练效率。 ### 回答2: 在现代机器学习的技术中,手写数字识别是一个相对简单的问题。然而,它的理论和技术都是非常有价值的。这个问题的目标是给机器一个图像,让它预测图像上的数字。这个任务对于许多现实世界的应用非常有用,例如自动识别支票或信用卡上的数字等。 Minist手写数字识别是一个流行的经典问题,它的目标是识别0-9的手写数字。这项任务已经在经典计算机视觉算法的研究中经常出现,被广泛使用,并且是许多机器学习算法和模型的基础。在这里,我们将使用PyTorch来实现这个任务。 首先,需要下载Minist数据集并准备数据。Minist数据集包含了70,000张28x28的灰度图像,每张图像代表了0到9之间的一个数字。数据集被分成了两个部分:60,000张图像用于训练,剩下的10,000张图像用于测试。 我们将使用PyTorch来构建一个卷积神经网络(CNN)来解决这个问题。这个CNN包括两个卷积层和两个全连接层。卷积层用于提取图像特征,它们通过卷积和池化操作将图像转换为低维的特征表示。全连接层则将这些特征映射到数字标签。 在训练CNN之前,我们需要对图像进行预处理和标准化。然后,我们将定义损失函数,优化器和学习率计划,以便在训练期间或在测试期间为CNN提供足够的准确性。 最后,我们将使用测试数据集来评估CNN的性能。为了更好的评估模型的性能,我们还可以使用k-fold交叉验证技术,以确保我们的CNN是健壮和可靠的。 总而言之,使用PyTorch来实现Minist手写数字识别是一个非常有趣和有收获的挑战。它不仅可以帮助我们了解机器学习中的经典问题,还可以帮助我们掌握深度学习技术和PyTorch的应用。 ### 回答3: Minist手写数字识别是深度学习领域中一个经典的问题。它的主要目标是通过机器学习的方法识别并分类手写数字。传统的机器学习方法使用手动设计的特征,但这种方法在处理高维、非线性数据时效果不理想。近年来,深度学习的发展使得自动学习特征成为可能,从而为Minist手写数字识别提供了新的解决方案。 在深度学习领域中,PyTorch是一种非常流行的框架,具有很强的灵活性和扩展性,被广泛用于各种机器学习问题的解决。PyTorch可以支持多种神经网络模型,包括卷积神经网络(CNN),循环神经网络(RNN)等。在Minist手写数字识别中,最常用的是CNN模型,因为CNN模型具有非常好的图像处理能力。而PyTorch中的CNN模型则可以通过简单的代码实现,下面是一个简单的CNN模型的代码: class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) 这个CNN模型包括两个卷积层(conv1和conv2)和两个全连接层(fc1和fc2)。其中,卷积层和全连接层都是通过PyTorch中的类来定义的。在forward()函数中,卷积和像素池化操作被串连在一起,用于从图像中提取特征。这些特征被展平并传递到全连接层中进行分类。 在PyTorch中,使用Minist手写数字数据集进行训练非常简单,因为PyTorch内置了MNIST数据集,并且提供了数据加载和预处理函数。使用该数据集可以轻松地训练CNN模型并进行手写数字识别。 综上所述,基于PyTorch实现Minist手写数字识别的CNN模型是一种相对容易的方法。使用PyTorch的灵活性和扩展性,可以定义并训练高性能的模型,并且可以通过各种方式来提高模型的准确性。该模型还可以与其他深度学习技术结合使用,例如迁移学习和增强学习,以进一步提高性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值