生成对抗网络介绍
生成对抗网络(Generative Adversarial Network,简称GAN)是一种深度学习模型,由Ian Goodfellow等人于2014年提出。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。GAN的目标是通过两个网络之间的对抗学习来生成逼真的数据。
- 生成器(Generator): 生成器是一个神经网络,它接收一个随机噪声向量作为输入,并试图将这个随机噪声转换为逼真的数据样本。在训练过程中,生成器不断试图提高生成样本的质量,使其能够欺骗判别器。初始阶段生成的样本可能不够真实,但随着训练的进行,生成器逐渐学会生成更加逼真的数据样本。
- 判别器(Discriminator): 判别器也是一个神经网络,它的任务是区分真实数据样本和由生成器生成的假样本。它类似于一个二分类器,努力将输入样本分为“真实”和“假”的两个类别。在训练过程中,判别器通过不断学习区分真实样本和生成样本,使得判别器的准确率不断提高。
GAN的训练过程是一个对抗过程:
- 生成器通过将随机噪声转换为生成样本,并将这些生成样本传递给判别器。
- 判别器根据传递给它的真实样本和生成样本对其进行分类,并输出相应的概率分数。
- 根据判别器的输出,生成器试图生成能够欺骗判别器的更逼真的样本。
- 这个过程不断重复,直到生成器生成的样本足够逼真,判别器无法准确区分真假样本。
通过这种对抗学习的过程,GAN能够生成高质量的数据样本,广泛应用于图像、音频、文本等领域。然而,训练GAN也存在一些挑战,如训练不稳定、模式崩溃等问题,需要经验丰富的研究人员进行调优和改进。
MNIST—GAN
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 定义生成器和判别器的类
class Generator(nn.Module):
def __init__(self, z_dim=100, hidden_dim=128, output_dim=784):
super(Generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(z_dim, hidden_dim),
nn.LeakyReLU(0.01),
nn.Linear(hidden_dim, hidden_dim * 2),
nn.LeakyReLU(0.01),
nn.Linear(hidden_dim * 2, output_dim),
nn.Tanh()
)
def forward(self, noise):
return self.gen(noise)
class Discriminator(nn.Module):
def __init__(self, input_dim=784, hidden_dim=128):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
nn.Linear(input_dim, hidden_dim * 2),
nn.LeakyReLU(0.01),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LeakyReLU(0.01),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, image):
return self.disc(image)
# 定义训练函数
def train_gan(generator, discriminator, dataloader, num_epochs=50, z_dim=100, lr=0.0002):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
gen_optim = optim.Adam(generator.parameters(), lr=lr)
disc_optim