深度学习——第五篇(生成式对抗网络)

目录

1 概述

1.1 生成式对抗网络

1.2 GAN的应用

2 GAN 算法实现数组拟合

参考文献


1 概述

1.1 生成式对抗网络

生成式对抗网络(Generative adversarial networks,GAN) 主要包括生成器( generator) 判别器( discriminator) 。其基本模型结构如图1所示。生成器通过学习真实图像分布(概率分布)从而使生成的图像更加真实,以欺骗判别器; 判别器(是一个二分类器)则需要对接收的图像进行真假判别。 在训练过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假。随着训练的迭代,生成器和判别器在不断地进行对抗,期望网络达到一个纳什均衡状态生成器生成的图像接近于真实的图像分布,而判别器识别不出真假图像,或者说判别器每次输出的概率基本为1/2,即判别器相当于随机猜测样本是真是假。[1]

图1 GAN基本模型结构[1]

GAN模型的训练过程如下:

在一个训练周期内,首先训练鉴别器。真实数据的标签为1,生成数据的标签为0,将这两种数据合起来让鉴别器去判断,训练鉴别器让其能准确分类;鉴别器训练完之后,将其固定,训练生成器,只给鉴别器生成的数据,训练生成器让鉴别器“误判”为1。

1.2 GAN的应用

  • 超分辨率图像的生成
  • 文本描述生成图像
  • 视频帧预测
  • 艺术风格的迁移
  • 检测恶意代码[2]
  • ......

2 GAN 算法实现数组拟合

目标是让生成器能生成与真实数组一样分布的数据。

导入需要的包

import torch
from torch import nn

import math
import matplotlib.pyplot as plt

准备训练数据(用于鉴别器的训练)

torch.manual_seed(111) # 初始化随机数种子

# 准备训练数据
# 由一对(x1,sin(x1))组成

train_data_length=2048
train_data=torch.zeros((train_data_length,2))
train_data[:,0]=2*math.pi*torch.rand(train_data_length) # 生成0-2pi的随机值
train_data[:,1]=torch.sin(train_data[:,0])
train_labels=torch.zeros(train_data_length) # 由于是无监督学习,因此标签没用,但是pytorch的数据加载器会用到
train_set=[(train_data[i],train_labels[i]) for i in range(train_data_length)] # 数据加载器,每一行是data和labels
plt.plot(train_data[:,0],train_data[:,1],'.')
plt.show()
图2 真实数据

创建数据加载器

batch_size=32
# 创建名为train_loader的数据加载器,返回32个样本
train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)

创建生成器和判别器网络

# 为生成器和判别器创建神经网络

# 判别器类继承自nn.module的类来表示
# 判别器是一个具有二维输入和一维输出的模型。从真实数据或者生成器中接收一个样本,并输出该样本属于真实训练数据的概率
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Linear(2,256),
            nn.ReLU(),
            nn.Dropout(0.3), # 避免过拟合
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64,1),
            nn.Sigmoid(), # 0-1 概率
        )
    def forward(self,x):
        output=self.model(x)
        return output

# 生成器实现
# 生成器是将隐空间的样本作为其输入,并生成与训练集中数据相类似的数据的模型
# 二维输入(接收随机点(z1,z2))和一个二维输出,输出(x1_hat,x2_hat)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Linear(2,16),
            nn.ReLU(),
            nn.Linear(16,32),
            nn.ReLU(),
            nn.Linear(32,2) # 生成任何值
        )
    def forward(self,x):
        output=self.model(x)
        return output

discriminator=Discriminator() # 鉴别器网络模型
generator=Generator() # 生成器网络模型

训练两个模型

每次先用真实数据和生成数据训练判别器模型(让判别器尽可能正确区分真伪),然后将判别器固定,用生成数据训练生成器(尽可能让鉴别器都判断为真,即“以假乱真”)

# 训练模型
lr=0.001
num_epochs=300
loss_function=nn.BCELoss() # 二元交叉熵损失函数

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)  # 判别器优化器
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)  # 生成器优化器

for epoch in range(num_epochs):
    for n,(real_samples,_)in enumerate(train_loader):
        # 每一次给出一个batch,直到所有样本都用完
        
        # 训练判别器的数据(生成器是固定的):
        real_samples_labels=torch.ones((batch_size,1))
        latent_space_samples=torch.randn((batch_size,2)) # 隐藏空间的样本batch
        generated_samples=generator(latent_space_samples) # 生成器生成的样本batch
        generated_samples_labels=torch.zeros((batch_size,1))
        all_samples=torch.cat((real_samples,generated_samples)) # 所有的样本
        all_samples_labels=torch.cat((real_samples_labels,generated_samples_labels)) # 1 0
        # 训练判别器
        optimizer_discriminator.zero_grad()
        output_discriminator=discriminator(all_samples) # 鉴别器输出(输入了包括真实样本的所有样本)
        loss_discriminator=loss_function(output_discriminator,all_samples_labels) # 通过结果看是否能够正确区分01
        loss_discriminator.backward(retain_graph=True) # 不释放图中的中间变量(需要多次反向传播)
        optimizer_discriminator.step()

        # 训练生成器的数据(鉴别器是固定的)
        latent_space_samples=torch.randn((batch_size,2))
        # 训练生成器
        optimizer_generator.zero_grad()
        generator_samples=generator(latent_space_samples)
        output_discriminator_generated=discriminator(generated_samples) # 只鉴别生成的
        loss_generator=loss_function(output_discriminator_generated,real_samples_labels) # 训练生成器以使鉴别器的输出都是1
        loss_generator.backward(retain_graph=True)
        optimizer_generator.step()

        if epoch % 10 ==0 and n==batch_size-1:
            print(f"Epoch:{epoch} Loss D:{loss_discriminator}")
            print(f"Epoch:{epoch} Loss G:{loss_generator}") 

生成器模型输出

latend_space_samples=torch.randn(100,2)
generated_samples=generator(latend_space_samples) # 模型生成的输出
enerated_samples=generated_samples.detach() # 将张量从当前计算图中分离出来,不参与梯度计算,通常用于模型输出。
plt.plot(enerated_samples[:,0],enerated_samples[:,1],".")
图3 生成器生成数据

参考文献

[1] 汪美琴,袁伟伟,张继业.生成对抗网络GAN的研究综述[J].计算机工程与设计,2021,42(12):3389-3395.DOI:10.16208/j.issn1000-7024.2021.12.012.

[2] 程显毅,谢璐,朱建新,等.生成对抗网络GAN综述[J].计算机科学,2019,46(03):74-81.

[3] Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化 – 拓端

  • 21
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值