GAN生成对抗网络原理推导(附加代码)

一、原理        

GAN(Generative Adversarial Network,生成对抗网络)是一种深度学习模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络彼此对抗,通过对抗训练的方式来学习生成逼真的数据样本。

1. 生成器(Generator):生成器的目标是学习生成与真实数据样本相似的数据。它接收一个随机噪声或其他形式的输入,然后通过一系列的神经网络层逐步将这个输入转换为所需的输出数据。生成器的目标是使其生成的数据尽可能地接近真实数据分布。

2. 判别器(Discriminator):判别器的任务是区分生成器生成的假数据和真实数据。它接收生成器生成的样本和真实数据样本作为输入,并输出一个概率,表示输入是真实数据的概率。判别器的目标是将生成器生成的假数据和真实数据区分开来。

GAN的训练过程基于两个网络之间的对抗。在训练过程中,生成器试图欺骗判别器,生成尽可能逼真的数据样本,而判别器则试图区分生成器生成的假数据和真实数据。通过这种对抗性的训练,生成器和判别器逐渐改进自己的性能,最终生成器可以生成与真实数据相似的样本。

GAN的训练过程可以概括为以下步骤:

1. 初始化网络参数:生成器和判别器的参数被随机初始化。

2. 交替训练:
   (生成器训练):生成器接收随机噪声作为输入,并生成数据样本。生成器的目标是使生成的样本尽可能地接近真实数据。生成器生成的样本通过判别器,然后根据判别器的反馈来更新生成器的参数。
   (判别器训练):判别器接收真实数据样本和生成器生成的假数据样本,并尝试将它们区分开来。判别器的目标是正确地分类真实数据和假数据。判别器的参数通过梯度下降来更新,以减少真实数据和假数据之间的分类误差。

3. 对抗训练:生成器和判别器交替训练,相互对抗,直到达到某个停止条件(如训练轮数达到预设值)。

4. 评估生成器:生成器训练完成后,可以使用它来生成新的数据样本。生成器生成的样本可以通过一些指标来评估其质量,例如与真实数据的相似度或用于某个任务的性能。

二、目标

我们的训练数据x xx是来自真实分布对应图中 P ( d a t a ) 

我们记作Pdata,训练数据都是从Pdata中采样得来(图中上半部分的x)。

而我们从简单的概率分布中抽样P(z)如正态分布 ,让所得的样本经过一个神经网络 G(z),得到一个新的样本x 这个样本就来自我们的需要求解的概率分布,我们记作P g。

然后将两个x给神经网络D(x)判断真伪,让它区分这个x是来自P data 还是P g,其输出样本来自Pdata的概率。依据所得信息使用梯度下降更新神经网络参数,G(z)也是如此。

而G(z)被称为生成器( 用于生成样本 ) ,D ( x ) 被称为判别器用于判别样本真伪
 

目标函数:

损失函数来自判别器和生成器

1、判别器

当样本来自P data,我们要让所得的概率越大越好;当样本来自p g,我们要让其概率越小越好,即              

将最小化换为最大化

所以单个样本判别器的损失函数可以写成

对于所有样本N,我们希望均值最大

写成期望形式得到判别器的损失函数( p_{data}x∼p data 表示样本来自真实分布)

2、生成器

它希望生成的样本让判别器判别为真的概率越大越好,所以直接设计成(将最大写成最小)

所以最终的目标函数可以写成

三、求最优解

 得到了目标函数,我们很显然还需要证明其存在最优解。并且最优解的P g是否和P data无限接近

先求里层关于D求最大值

要求积分最大,就是要求里面的没一个最大

求导整理后得

将其带入目标函数,并且关于外层G求最小

由此可见目标函数最优值能够让Pg逼近Pdata,并且当期相等时,有

也就是判断器再也无法判断出样本来自Pg还是Pdata。

四、代码完整版

# -*- coding: utf-8 -*-
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm  import tqdm
import matplotlib.pyplot as plt

class Generate_Model(torch.nn.Module):
    '''
    生成器
    '''
    def __init__(self):
        super().__init__()
        self.fc=torch.nn.Sequential(
            torch.nn.Linear(in_features=128,out_features=256),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=256,out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512,out_features=784),
            torch.nn.Tanh()
        )
    def forward(self,x):
        x=self.fc(x)
        return x

class Distinguish_Model(torch.nn.Module):
    '''
    判别器
    '''
    def __init__(self):
        super().__init__()
        self.fc=torch.nn.Sequential(
            torch.nn.Linear(in_features=784,out_features=512),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=512,out_features=256),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=256,out_features=128),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=128,out_features=1),
            torch.nn.Sigmoid()
        )
    def forward(self,x):
        x=self.fc(x)
        return x
def train():
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #判断是否存在可用GPU
    transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5)
    ]) #图片标准化
    train_data = MNIST("./data", transform=transformer,download=True) #载入图片
    dataloader = DataLoader(train_data, batch_size=64,num_workers=4, shuffle=True) #将图片放入数据加载器

    D = Distinguish_Model().to(device) #实例化判别器
    G = Generate_Model().to(device) #实例化生成器

    D_optim = torch.optim.Adam(D.parameters(), lr=1e-4) #为判别器设置优化器
    G_optim = torch.optim.Adam(G.parameters(), lr=1e-4) #为生成器设置优化器

    loss_fn = torch.nn.BCELoss() #损失函数

    epochs = 100 #迭代100次
    for epoch in range(epochs):
        dis_loss_all=0 #记录判别器损失损失
        gen_loss_all=0 #记录生成器损失
        loader_len=len(dataloader) #数据加载器长度
        for step,data in tqdm(enumerate(dataloader), desc="第{}轮".format(epoch),total=loader_len):
            # 先计算判别器损失
            sample,label=data #获取样本,舍弃标签
            sample = sample.reshape(-1, 784).to(device) #重塑图片
            sample_shape = sample.shape[0] #获取批次数量
            #从正态分布中抽样
            sample_z = torch.normal(0, 1, size=(sample_shape, 128),device=device)

            Dis_true = D(sample) #判别器判别真样本

            true_loss = loss_fn(Dis_true, torch.ones_like(Dis_true)) #计算损失

            fake_sample = G(sample_z) #生成器通过正态分布抽样生成数据
            Dis_fake = D(fake_sample.detach()) #判别器判别伪样本
            fake_loss = loss_fn(Dis_fake, torch.zeros_like(Dis_fake)) #计算损失

            Dis_loss = true_loss + fake_loss #真假加起来
            D_optim.zero_grad()
            Dis_loss.backward() #反向传播
            D_optim.step()

            # 生成器损失
            Dis_G = D(fake_sample) #判别器判别
            G_loss = loss_fn(Dis_G, torch.ones_like(Dis_G)) #计算损失
            G_optim.zero_grad()
            G_loss.backward() #反向传播
            G_optim.step()
            with torch.no_grad():
                dis_loss_all+=Dis_loss #判别器累加损失
                gen_loss_all+=G_loss #生成器累加损失
        with torch.no_grad():
            dis_loss_all=dis_loss_all/loader_len
            gen_loss_all=gen_loss_all/loader_len
            print("判别器损失为:{}".format(dis_loss_all))
            print("生成器损失为:{}".format(gen_loss_all))
        torch.save(G, "./G.pth") #保存模型
        torch.save(D, "./D.pth") #保存模型
if __name__ == '__main__':
    train() #训练模型
    model_G=torch.load("./G.pth",map_location=torch.device("cpu")) #载入模型
    fake_z=torch.normal(0,1,size=(10,128))  #抽样数据
    result=model_G(fake_z).reshape(-1,28,28)  #生成数据
    result=result.detach().numpy()

    #绘制
    for i in range(10):
        plt.subplot(2,5,i+1)
        plt.imshow(result[i])
        plt.gray()
    plt.show()

五、结果

  • 18
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值