目录
1 概述
1.1 生成式对抗网络
生成式对抗网络(Generative adversarial networks,GAN) 主要包括生成器( generator) 与判别器( discriminator) 。其基本模型结构如图1所示。生成器通过学习真实图像分布(概率分布)从而使生成的图像更加真实,以欺骗判别器; 判别器(是一个二分类器)则需要对接收的图像进行真假判别。 在训练过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假。随着训练的迭代,生成器和判别器在不断地进行对抗,期望网络达到一个纳什均衡状态:生成器生成的图像接近于真实的图像分布,而判别器识别不出真假图像,或者说判别器每次输出的概率基本为1/2,即判别器相当于随机猜测样本是真是假。[1]
![](https://img-blog.csdnimg.cn/direct/2a776c7eb97c40f181292ac399ae9df7.png)
GAN模型的训练过程如下:
在一个训练周期内,首先训练鉴别器。真实数据的标签为1,生成数据的标签为0,将这两种数据合起来让鉴别器去判断,训练鉴别器让其能准确分类;鉴别器训练完之后,将其固定,训练生成器,只给鉴别器生成的数据,训练生成器让鉴别器“误判”为1。
1.2 GAN的应用
- 超分辨率图像的生成
- 文本描述生成图像
- 视频帧预测
- 艺术风格的迁移
- 检测恶意代码[2]
- ......
2 GAN 算法实现数组拟合
目标是让生成器能生成与真实数组一样分布的数据。
导入需要的包
import torch
from torch import nn
import math
import matplotlib.pyplot as plt
准备训练数据(用于鉴别器的训练)
torch.manual_seed(111) # 初始化随机数种子
# 准备训练数据
# 由一对(x1,sin(x1))组成
train_data_length=2048
train_data=torch.zeros((train_data_length,2))
train_data[:,0]=2*math.pi*torch.rand(train_data_length) # 生成0-2pi的随机值
train_data[:,1]=torch.sin(train_data[:,0])
train_labels=torch.zeros(train_data_length) # 由于是无监督学习,因此标签没用,但是pytorch的数据加载器会用到
train_set=[(train_data[i],train_labels[i]) for i in range(train_data_length)] # 数据加载器,每一行是data和labels
plt.plot(train_data[:,0],train_data[:,1],'.')
plt.show()
![](https://img-blog.csdnimg.cn/direct/d16192bc51724b06bad4f986367c4d64.png)
创建数据加载器
batch_size=32
# 创建名为train_loader的数据加载器,返回32个样本
train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)
创建生成器和判别器网络
# 为生成器和判别器创建神经网络
# 判别器类继承自nn.module的类来表示
# 判别器是一个具有二维输入和一维输出的模型。从真实数据或者生成器中接收一个样本,并输出该样本属于真实训练数据的概率
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model=nn.Sequential(
nn.Linear(2,256),
nn.ReLU(),
nn.Dropout(0.3), # 避免过拟合
nn.Linear(256,128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128,64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64,1),
nn.Sigmoid(), # 0-1 概率
)
def forward(self,x):
output=self.model(x)
return output
# 生成器实现
# 生成器是将隐空间的样本作为其输入,并生成与训练集中数据相类似的数据的模型
# 二维输入(接收随机点(z1,z2))和一个二维输出,输出(x1_hat,x2_hat)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model=nn.Sequential(
nn.Linear(2,16),
nn.ReLU(),
nn.Linear(16,32),
nn.ReLU(),
nn.Linear(32,2) # 生成任何值
)
def forward(self,x):
output=self.model(x)
return output
discriminator=Discriminator() # 鉴别器网络模型
generator=Generator() # 生成器网络模型
训练两个模型
每次先用真实数据和生成数据训练判别器模型(让判别器尽可能正确区分真伪),然后将判别器固定,用生成数据训练生成器(尽可能让鉴别器都判断为真,即“以假乱真”)
# 训练模型
lr=0.001
num_epochs=300
loss_function=nn.BCELoss() # 二元交叉熵损失函数
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr) # 判别器优化器
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr) # 生成器优化器
for epoch in range(num_epochs):
for n,(real_samples,_)in enumerate(train_loader):
# 每一次给出一个batch,直到所有样本都用完
# 训练判别器的数据(生成器是固定的):
real_samples_labels=torch.ones((batch_size,1))
latent_space_samples=torch.randn((batch_size,2)) # 隐藏空间的样本batch
generated_samples=generator(latent_space_samples) # 生成器生成的样本batch
generated_samples_labels=torch.zeros((batch_size,1))
all_samples=torch.cat((real_samples,generated_samples)) # 所有的样本
all_samples_labels=torch.cat((real_samples_labels,generated_samples_labels)) # 1 0
# 训练判别器
optimizer_discriminator.zero_grad()
output_discriminator=discriminator(all_samples) # 鉴别器输出(输入了包括真实样本的所有样本)
loss_discriminator=loss_function(output_discriminator,all_samples_labels) # 通过结果看是否能够正确区分01
loss_discriminator.backward(retain_graph=True) # 不释放图中的中间变量(需要多次反向传播)
optimizer_discriminator.step()
# 训练生成器的数据(鉴别器是固定的)
latent_space_samples=torch.randn((batch_size,2))
# 训练生成器
optimizer_generator.zero_grad()
generator_samples=generator(latent_space_samples)
output_discriminator_generated=discriminator(generated_samples) # 只鉴别生成的
loss_generator=loss_function(output_discriminator_generated,real_samples_labels) # 训练生成器以使鉴别器的输出都是1
loss_generator.backward(retain_graph=True)
optimizer_generator.step()
if epoch % 10 ==0 and n==batch_size-1:
print(f"Epoch:{epoch} Loss D:{loss_discriminator}")
print(f"Epoch:{epoch} Loss G:{loss_generator}")
生成器模型输出
latend_space_samples=torch.randn(100,2)
generated_samples=generator(latend_space_samples) # 模型生成的输出
enerated_samples=generated_samples.detach() # 将张量从当前计算图中分离出来,不参与梯度计算,通常用于模型输出。
plt.plot(enerated_samples[:,0],enerated_samples[:,1],".")
![](https://img-blog.csdnimg.cn/direct/62d7391a6cbb4ba49e9b1da040f09d9c.png)
参考文献
[1] 汪美琴,袁伟伟,张继业.生成对抗网络GAN的研究综述[J].计算机工程与设计,2021,42(12):3389-3395.DOI:10.16208/j.issn1000-7024.2021.12.012.
[2] 程显毅,谢璐,朱建新,等.生成对抗网络GAN综述[J].计算机科学,2019,46(03):74-81.