生成对抗网络GAN 实现 手写体生成

一、网络结构

网络由生成器和判别器两部分组成,网络结构图如下所示:

image-20230319212303374

生成器功能: 生成接近真实样本的假样本

生成器的输入: 随机噪声

生成器的输出: 与真实样本相同大小的假样本

判别器的功能: 判断输入是真样本还是假样本

判别器的输入: 真实样本 和 生成器生成的假样本

判别器的输出: 结果是真实样本的概率值, p∈[0, 1]

二、损失函数

image-20230319205752437

损失函数第一部分: E x~pdata(x)[logD(x)]

从数据集随机抽样出来的真实样本,判别器将其判断为1

损失函数第二部分: Ez~pnoize(z)[log(1-D(G(z)))]

对于生成器而言,希望判别器将G(z)判别为真实数据,即D(G(z))的结果接近1,总体损失越小越好。

对于鉴别器而言,希望正确的判断生成器生成的假数据,即D(G(z))的结果接近0,总体损失越大越好。

由此可见,生成器和鉴别器的优化目标是相反的,即最小最大优化
在这里插入图片描述

三、生成对抗逻辑

一方面,生成器会不断提高自己造假的能力,不断进化,以骗过鉴别器。

另一方面,鉴别器也会不断提高自己的鉴别能力,不断升级,以分别真假样本。

因此,通过这种对抗训练的方式,生成器和鉴别器的能力都被不断提高,从而使得生成器能够生成较为真实的假样本。

其过程大致如下:

  1. v1版本的Generator只能生成较为模糊的图片,v1版本的Discriminator很容易鉴别出它是假样本。

  2. 然后,Generator升级为v2版本,能生成有颜色有眼睛的图片。此时,已经能够骗过v1版本的Discriminator。

  3. 接着,Discriminator也升级为v2版本,学习到了如何鉴别v2版Generator生成的假样本。

  4. 然后,Generator再继续升级… 接着,Discriminator也跟着升级…

image-20230319210922681

四、案例实现

利用MINIST数据集,训练一个能生成手写体数字的生成器。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 读取数据 并对数据进行归一化[-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1];将图片改为channel*height*width
    transforms.Normalize(0.5, 0.5)  # 标准化至[-1, 1];规定均值和标准差
])

# 训练集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)


# 定义生成器
# 生成器的输入:长度为100的随机噪声(正太分布随机数)
# 生成器的输出:生成的图片(1*28*28)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),  # 输入为100维的随机噪声,输出为256维的特征
            nn.ReLU(),  # 激活函数
            nn.Linear(256, 512),
            nn.ReLU(),  # 激活函数v
            nn.Linear(512, 28 * 28),  # 输出为28*28的图片
            nn.Tanh()  # 激活函数
        )

    def forward(self, input):
        output = self.main(input)
        output = output.view(-1, 28, 28)  # 将输出的图片reshape为1*28*28
        return output


# 定义判别器
# 判别器的输入:真实图片(1*28*28) / 生成的图片(1*28*28)
# 判别器的输出:判断图片为真的概率(0~1)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),  # 输入为28*28的图片,输出为512维的特征
            nn.LeakyReLU(),  # 激活函数
            nn.Linear(512, 256),
            nn.LeakyReLU(),  # 激活函数
            nn.Linear(256, 1),
            nn.Sigmoid()  # 激活函数
        )

    def forward(self, input):
        input = input.view(-1, 28 * 28)  # 将输入的图片reshape为28*28
        output = self.main(input)
        return output


# 定义损失函数
loss_fn = torch.nn.BCELoss()  # 二分类交叉熵损失函数

# 定义优化器
# 生成器的优化器
gen = Generator().to(device)
g_optim = optim.Adam(gen.parameters(), lr=0.0001)
# 判别器的优化器
dis = Discriminator().to(device)
d_optim = optim.Adam(dis.parameters(), lr=0.0001)


# 绘图函数
def plot_img(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()


test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []

writer = SummaryWriter(log_dir='runs/logs')

# 训练
for epoch in range(50):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)

    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        # 训练判别器
        d_optim.zero_grad()
        # 1.1 真实图片
        real_output = dis(img)  # 判别器判断真实图片的概率
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实图片的损失
        d_real_loss.backward()  # 反向传播

        # 1.2 生成图片
        fake_img = gen(random_noise)  # 生成图片
        fake_output = dis(fake_img.detach())  # 判别器判断生成图片的概率
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 生成图片的损失
        d_fake_loss.backward()  # 反向传播

        # 1.3 损失函数
        d_loss = d_real_loss + d_fake_loss  # 判别器的损失
        d_optim.step()  # 更新参数

        # 训练生成器
        g_optim.zero_grad()
        fake_output = dis(fake_img)  # 判别器判断生成图片的概率
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成图片的损失
        g_loss.backward()  # 反向传播
        g_optim.step()  # 更新参数

        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()

    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch: {}, Step: {}, D_loss: {:.4f}, G_loss: {:.4f}'.format(epoch, step, d_epoch_loss, g_epoch_loss))
        writer.add_scalar('D_loss', d_epoch_loss, epoch)
        writer.add_scalar('G_loss', g_epoch_loss, epoch)
        plot_img(gen, test_input)

五、实现结果

网络训练过程中,随着epoch的增多,生成器生成的结果越来越真实。具体图像如下所示:

Epoch1Epoch10Epoch30Epoch50
image-20230320104017547image-20230320104226399image-20230320104319127image-20230320104410955
D_lossG_loss
image-20230320104508284image-20230320104522397

如上图所示,由两个损失曲线可以看出:

对于生成器而言,它的损失一开始较大,在训练过程中呈减小的趋势;

原因分析:生成器是从随机噪声开始学习,因此最初的时候损失会比较大。

对于鉴别器而言,它的损失一开始较小,在训练过程中呈增大的趋势;

原因分析:在最开始的时候,鉴别器很容易判断学得不好的生成器生成的假样本,但生成器在对抗训练中不断改进提升,所以鉴别器鉴别真假样本越来越难。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值