1、引言
小屌丝:鱼哥,忙吗?
小鱼:忙
小屌丝:哦,那我约别人去吧
小鱼:… 去哪?
小屌丝:反正你也很忙, 就不跟你说了,
小鱼:… 你说说看去哪?
小屌丝:不说了不说了, 说了耽误你工作
小鱼:那我不忙
小屌丝:确定不忙啊
小鱼:必须得,不忙, 你说去哪?
小屌丝: 不忙的话, 给我讲一讲GAN呗
小鱼:…
小屌丝:那你可说好了?
小鱼:… 正人君子,休想套路我
小屌丝:那我可去喽?
小鱼: …唉~ 别这样啊, 我说GAN ,你带我去。
小屌丝:成交
2、生成对抗网络
2.1 定义
生成对抗网络(Generative Adversarial Networks,简称GAN)是深度学习领域中的一种生成模型,由Ian J. Goodfellow等人于2014年提出。
GAN的核心思想是通过让两个神经网络:
- 生成器(Generator)和
- 判别器(Discriminator)
进行对抗训练,以达到生成接近真实数据分布的人工样本的目的。
2.2 原理
GAN的原理基于一种“零和游戏”的博弈思想。
- 在这个游戏中,生成器试图生成逼真的假样本来欺骗判别器,而判别器则努力区分输入的样本是来自真实数据分布还是生成器生成的假样本。
- 两者通过相互竞争的方式进行训练,最终达到一种动态平衡状态,即纳什均衡(Nash Equilibrium)。
具体来说,生成器接受一个随机噪声作为输入,通过一系列的非线性变换生成一个输出样本。
判别器则接收一个输入样本(可能是真实样本或生成样本),并输出一个概率值,表示该样本是真实样本的可能性。
在训练过程中,生成器和判别器通过反向传播算法更新各自的参数,以最大化自己的损失函数。
当训练达到收敛时,生成器能够生成与真实数据分布难以区分的样本,而判别器对于任何输入样本的输出概率都接近0.5(即无法区分真假)。
此时,可以认为GAN已经学会了真实数据的分布。
2.3 实现方式
GAN的实现方式主要包括以下几个步骤:
- 定义网络结构:首先,需要定义生成器和判别器的网络结构。这两个网络通常采用深度神经网络(如卷积神经网络CNN)来实现。
- 初始化网络参数:在训练开始前,需要对生成器和判别器的参数进行初始化。通常使用随机初始化方法,如高斯分布初始化。
- 前向传播:对于每一批训练数据,将真实数据和随机噪声分别输入到判别器和生成器中,进行前向传播计算得到输出。
- 计算损失函数:根据判别器和生成器的输出,计算各自的损失函数。判别器的损失函数通常包括两部分:一部分是真实样本的损失(即真实样本被判为假的损失),另一部分是生成样本的损失(即生成样本被判为真的损失)。生成器的损失函数则是基于判别器对生成样本的判别结果来计算的。
- 反向传播:根据损失函数的梯度信息,使用反向传播算法更新生成器和判别器的参数。
- 迭代训练:重复以上步骤,直到达到预设的训练轮数或满足其他停止条件。
2.4 算法公式
GAN的核心是通过优化以下的min-max公式来训练生成器和判别器: [ min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] ] [ \min_G \max_D V(D, G) = \mathbb{E}{x\sim p{data}(x)}[\log D(x)] + \mathbb{E}{z\sim p{z}(z)}[\log(1 - D(G(z)))] ] [GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]]
其中, ( G ( z ) ) (G(z)) (G(z))表示生成器试图通过输入噪声 ( z ) (z) (z)生成的数据, ( D ( x ) ) (D(x)) (D(x))表示判别器对于给定输入 ( x ) (x) (x)的判断(即该数据是真实的概率)
2.5 代码示例
# -*- coding:utf-8 -*-
# @Time : 2024-01-21
# @Author : Carl_DJ
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 设定超参数
batch_size = 64
learning_rate = 0.0002
epochs = 50
latent_dim = 100 # 噪声向量的维度
# 数据预处理和加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 构建生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 28*28),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28)
return img
# 构建判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28*28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()
# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 损失函数
adversarial_loss = nn.BCELoss()
# 训练
for epoch in range(epochs):
for i, (imgs, _) in enumerate(train_loader):
# 真实数据和假数据的标签
real = torch.ones(imgs.size(0), 1)
fake = torch.zeros(imgs.size(0), 1)
# 训练生成器
optimizer_G.zero_grad()
z = torch.randn(imgs.size(0), latent_dim)
generated_imgs = generator(z)
g_loss = adversarial_loss(discriminator(generated_imgs), real)
g_loss.backward()
optimizer_G.step()
# 训练判别器
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(imgs), real)
fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# 打印进度
print(f"Epoch [{epoch+1}/{epochs}] Batch {i+1}/{len(train_loader)} Loss D: {d_loss.item()}, loss G: {g_loss.item()}")
代码解析
-
首先设置了一些超参数,如批大小、学习率、训练周期和潜在空间维度。
-
然后,它定义了两个关键的神经网络:生成器和判别器。
- 生成器旨在从随机噪声中生成手写数字图像,
- 而判别器则尝试区分真实图像和生成的图像。
- 通过交替训练这两个模型,GAN能够生成越来越逼真的手写数字图像。
-
在训练过程中,我们分别计算生成器和判别器的损失,并通过反向传播更新它们的权重。
-
随着训练的进行,生成器将变得越来越擅长生成逼真的图像,而判别器则会变得越来越擅长区分真假图像。
3、总结
生成对抗网络(GAN)作为深度学习领域的一种强大生成模型,已经在图像生成、图像修复、图像转换、文本生成等多个领域取得了显著成果。
其基本原理是通过让生成器和判别器进行对抗训练,以达到生成接近真实数据分布的人工样本的目的。
GAN的实现过程涉及到深度神经网络、优化算法、损失函数等多个方面的知识。
我是小鱼:
- CSDN 博客专家;
- 阿里云 专家博主;
- 51CTO博客专家;
- 企业认证金牌面试官;
- 多个名企认证&特邀讲师等;
- 名企签约职场面试培训、职场规划师;
- 多个国内主流技术社区的认证专家博主;
- 多款主流产品(阿里云等)测评一、二等奖获得者;
关注小鱼,学习【机器学习】&【深度学习】领域的知识。