GANs (Generative Adversarial Networks)

Basic principles of Generative Adversarial Networks (GANs):

A Generative Adversarial Network (GAN) is a deep learning model proposed by Ian Goodfellow et al. in 2014 for generating realistic data samples. A GAN consists of two main components: a generator and a discriminator. The two networks are trained against each other: over the course of training the generator produces increasingly realistic data, while the discriminator becomes increasingly good at telling real data apart from the generator's output.

Recommended video: What is a GAN (Generative Adversarial Network)?【知多少】

How it works:

  1. Generator: takes a random noise vector as input and passes it through a neural network to produce synthetic data that resembles the real data. The generator's goal is to fool the discriminator so that it cannot tell generated samples from real ones.

  2. Discriminator: receives both real data and data produced by the generator and tries to tell which inputs are real and which are generated. The discriminator's goal is to judge the source of each input accurately.

  3. Adversarial training: during training the generator and discriminator compete. The generator wants its samples to look realistic enough to fool the discriminator, while the discriminator wants to separate real data from generated data ever more reliably. This adversarial process is what eventually drives the generator to produce realistic data (see the toy sketch after this list).
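
To make the alternation concrete, here is a toy single-update sketch in PyTorch, using 1-D "data" and placeholder linear networks (illustrative only; the full MNIST implementation appears later in this post):

import torch
import torch.nn as nn

G = nn.Linear(1, 1)                                # toy generator: noise -> "data"
D = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())   # toy discriminator: data -> P(real)
opt_G = torch.optim.SGD(G.parameters(), lr=0.1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)
bce = nn.BCELoss()

real = torch.randn(8, 1) + 3.0                     # pretend the real data follows N(3, 1)
noise = torch.randn(8, 1)

# Discriminator step: push D(real) toward 1 and D(fake) toward 0
opt_D.zero_grad()
fake = G(noise)
loss_D = bce(D(real), torch.ones(8, 1)) + bce(D(fake.detach()), torch.zeros(8, 1))
loss_D.backward()
opt_D.step()

# Generator step: push D(fake) toward 1 so the fakes are mistaken for real data
opt_G.zero_grad()
loss_G = bce(D(fake), torch.ones(8, 1))
loss_G.backward()
opt_G.step()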

Mathematical formulation:

GAN training is formulated as a two-player minimax game over a single value function V(D, G):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

  1. Discriminator objective: D is trained to maximize V(D, G), i.e. to assign a high probability D(x) to real samples and a low probability D(G(z)) to generated ones.

  2. Generator objective: G is trained to minimize V(D, G), i.e. to make D(G(z)) as large as possible so that generated samples are mistaken for real ones.

Here, p_{\text{data}}(x) is the distribution of the real data and p_z(z) is the distribution of the noise fed to the generator.
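
One practical detail: early in training, when the discriminator easily rejects the generated samples, the \log(1 - D(G(z))) term gives the generator very weak gradients. The usual remedy, suggested in the original paper and also what the code below implements, is the non-saturating generator loss

\max_G \; \mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]

which in code amounts to training the generator with binary cross-entropy against the "real" label.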

Python code example:

Below is a simple GAN implemented with PyTorch:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the generator: maps a noise vector to a 1x28x28 image
# (a linear layer produces a 7x7 feature map, then two transposed convolutions upsample 7 -> 14 -> 28)
class GeneratorCNN(nn.Module):
    def __init__(self, input_size):
        super(GeneratorCNN, self).__init__()
        self.fc1 = nn.Linear(input_size, 256 * 7 * 7)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv_trans1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv_trans2 = nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.fc1(x)
        x = x.view(-1, 256, 7, 7)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.conv_trans1(x)
        x = self.bn2(x)
        x = nn.functional.relu(x)
        x = self.conv_trans2(x)
        x = self.tanh(x)
        return x

# Define the discriminator: maps a 1x28x28 image to a single "real" probability
# (two stride-2 convolutions downsample 28 -> 14 -> 7, then a linear layer + sigmoid)
class DiscriminatorCNN(nn.Module):
    def __init__(self):
        super(DiscriminatorCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128 * 7 * 7, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.leaky_relu(x, 0.2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.functional.leaky_relu(x, 0.2)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc1(x)
        x = self.sigmoid(x)
        return x
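
# Optional shape check: for a 28x28 MNIST input, each stride-2 convolution halves the
# spatial size (28 -> 14 -> 7), so the flattened feature has 128 * 7 * 7 elements,
# matching fc1 above. Uncomment to verify:
# print(DiscriminatorCNN()(torch.randn(2, 1, 28, 28)).shape)  # expected: torch.Size([2, 1])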

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 64
    input_size = 100  # dimensionality of the random noise vector fed to the generator

    # Load the MNIST dataset (normalized to [-1, 1] to match the generator's Tanh output)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Initialize the generator, the discriminator, and their optimizers
    generator = GeneratorCNN(input_size).to(device)
    discriminator = DiscriminatorCNN().to(device)
    # (For more stable GAN training, Adam with betas=(0.5, 0.999) is a common choice.)
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
    criterion = nn.BCELoss()

    # Train the GAN
    num_epochs = 10
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(data_loader):
            real_images = real_images.to(device)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Train the discriminator: real images should score 1, generated images 0
            optimizer_D.zero_grad()
            noise = torch.randn(batch_size, input_size, device=device)
            fake_images = generator(noise)
            output_real = discriminator(real_images)
            output_fake = discriminator(fake_images.detach())  # detach() keeps gradients from flowing back into the generator
            loss_real = criterion(output_real, real_labels)
            loss_fake = criterion(output_fake, fake_labels)
            loss_D = loss_real + loss_fake
            loss_D.backward()
            optimizer_D.step()

            # Train the generator: make the discriminator output 1 on generated images (non-saturating loss)
            optimizer_G.zero_grad()
            output_fake = discriminator(fake_images)
            loss_G = criterion(output_fake, real_labels)
            loss_G.backward()
            optimizer_G.step()

            # Print training progress
            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(data_loader)}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')

    # Generate images with the trained generator
    with torch.no_grad():
        noise = torch.randn(16, input_size, device=device)
        generated_images = generator(noise).reshape(-1, 28, 28).cpu().numpy()

    # Visualize the generated images (Tanh output in [-1, 1]; imshow rescales automatically)
    plt.figure(figsize=(4, 4))
    for i in range(generated_images.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(generated_images[i], cmap='gray')
        plt.axis('off')
    plt.show()

Output:

Using device: cuda
Epoch [0/10], Batch [0/937], Loss D: 1.6401, Loss G: 0.8726
Epoch [0/10], Batch [100/937], Loss D: 0.0038, Loss G: 6.5093
Epoch [0/10], Batch [200/937], Loss D: 0.0026, Loss G: 7.3444
Epoch [0/10], Batch [300/937], Loss D: 0.0040, Loss G: 6.1462
Epoch [0/10], Batch [400/937], Loss D: 0.0011, Loss G: 7.5719
Epoch [0/10], Batch [500/937], Loss D: 0.0006, Loss G: 8.2353
Epoch [0/10], Batch [600/937], Loss D: 0.0010, Loss G: 8.7569
Epoch [0/10], Batch [700/937], Loss D: 0.0066, Loss G: 5.4503
Epoch [0/10], Batch [800/937], Loss D: 0.0017, Loss G: 16.3379
Epoch [0/10], Batch [900/937], Loss D: 0.0017, Loss G: 6.9244
Epoch [1/10], Batch [0/937], Loss D: 0.0009, Loss G: 7.9770
Epoch [1/10], Batch [100/937], Loss D: 0.0010, Loss G: 14.9311
Epoch [1/10], Batch [200/937], Loss D: 0.0025, Loss G: 9.1577
Epoch [1/10], Batch [300/937], Loss D: 0.0027, Loss G: 9.3959
Epoch [1/10], Batch [400/937], Loss D: 0.0029, Loss G: 8.8815
Epoch [1/10], Batch [500/937], Loss D: 0.0040, Loss G: 6.0449
Epoch [1/10], Batch [600/937], Loss D: 0.0011, Loss G: 8.1859
Epoch [1/10], Batch [700/937], Loss D: 0.0005, Loss G: 8.5195
Epoch [1/10], Batch [800/937], Loss D: 0.0010, Loss G: 12.9067
Epoch [1/10], Batch [900/937], Loss D: 0.0048, Loss G: 6.8532
Epoch [2/10], Batch [0/937], Loss D: 0.0015, Loss G: 8.1116
Epoch [2/10], Batch [100/937], Loss D: 0.0093, Loss G: 5.7263
Epoch [2/10], Batch [200/937], Loss D: 0.0026, Loss G: 7.5310
Epoch [2/10], Batch [300/937], Loss D: 0.0093, Loss G: 6.2506
Epoch [2/10], Batch [400/937], Loss D: 0.0136, Loss G: 5.7398
Epoch [2/10], Batch [500/937], Loss D: 0.0367, Loss G: 4.2857
Epoch [2/10], Batch [600/937], Loss D: 0.0212, Loss G: 5.2088
Epoch [2/10], Batch [700/937], Loss D: 0.0279, Loss G: 6.0875
Epoch [2/10], Batch [800/937], Loss D: 0.0719, Loss G: 5.4136
Epoch [2/10], Batch [900/937], Loss D: 0.0313, Loss G: 4.8517
Epoch [3/10], Batch [0/937], Loss D: 0.0193, Loss G: 6.7583
Epoch [3/10], Batch [100/937], Loss D: 0.1850, Loss G: 3.9395
Epoch [3/10], Batch [200/937], Loss D: 0.1135, Loss G: 4.4806
Epoch [3/10], Batch [300/937], Loss D: 0.4871, Loss G: 3.3336
Epoch [3/10], Batch [400/937], Loss D: 0.1662, Loss G: 2.6347
Epoch [3/10], Batch [500/937], Loss D: 0.2953, Loss G: 3.7447
Epoch [3/10], Batch [600/937], Loss D: 0.3024, Loss G: 2.5100
Epoch [3/10], Batch [700/937], Loss D: 0.3068, Loss G: 2.8986
Epoch [3/10], Batch [800/937], Loss D: 0.1014, Loss G: 3.3972
Epoch [3/10], Batch [900/937], Loss D: 0.1630, Loss G: 4.3400
Epoch [4/10], Batch [0/937], Loss D: 0.1616, Loss G: 3.5032
Epoch [4/10], Batch [100/937], Loss D: 0.4845, Loss G: 2.4996
Epoch [4/10], Batch [200/937], Loss D: 0.1928, Loss G: 3.0287
Epoch [4/10], Batch [300/937], Loss D: 0.2848, Loss G: 2.2172
Epoch [4/10], Batch [400/937], Loss D: 0.3643, Loss G: 2.5848
Epoch [4/10], Batch [500/937], Loss D: 0.1544, Loss G: 3.2025
Epoch [4/10], Batch [600/937], Loss D: 0.1988, Loss G: 2.8276
Epoch [4/10], Batch [700/937], Loss D: 0.1579, Loss G: 3.8011
Epoch [4/10], Batch [800/937], Loss D: 0.2830, Loss G: 2.7761
Epoch [4/10], Batch [900/937], Loss D: 0.2649, Loss G: 2.4664
Epoch [5/10], Batch [0/937], Loss D: 0.2667, Loss G: 3.0192
Epoch [5/10], Batch [100/937], Loss D: 0.2094, Loss G: 2.9665
Epoch [5/10], Batch [200/937], Loss D: 0.1918, Loss G: 3.5549
Epoch [5/10], Batch [300/937], Loss D: 0.1342, Loss G: 2.7686
Epoch [5/10], Batch [400/937], Loss D: 0.1424, Loss G: 3.3407
Epoch [5/10], Batch [500/937], Loss D: 0.2777, Loss G: 2.8267
Epoch [5/10], Batch [600/937], Loss D: 0.2675, Loss G: 4.7375
Epoch [5/10], Batch [700/937], Loss D: 0.2871, Loss G: 3.1842
Epoch [5/10], Batch [800/937], Loss D: 0.2563, Loss G: 2.7062
Epoch [5/10], Batch [900/937], Loss D: 0.2027, Loss G: 3.3090
Epoch [6/10], Batch [0/937], Loss D: 0.2655, Loss G: 2.6006
Epoch [6/10], Batch [100/937], Loss D: 0.2564, Loss G: 2.7532
Epoch [6/10], Batch [200/937], Loss D: 0.2245, Loss G: 3.0649
Epoch [6/10], Batch [300/937], Loss D: 0.2071, Loss G: 2.4969
Epoch [6/10], Batch [400/937], Loss D: 0.2836, Loss G: 3.0069
Epoch [6/10], Batch [500/937], Loss D: 0.2760, Loss G: 2.7501
Epoch [6/10], Batch [600/937], Loss D: 0.4911, Loss G: 4.4188
Epoch [6/10], Batch [700/937], Loss D: 0.2641, Loss G: 3.7137
Epoch [6/10], Batch [800/937], Loss D: 0.3598, Loss G: 2.8660
Epoch [6/10], Batch [900/937], Loss D: 0.2570, Loss G: 2.8966
Epoch [7/10], Batch [0/937], Loss D: 0.3485, Loss G: 2.5611
Epoch [7/10], Batch [100/937], Loss D: 0.1675, Loss G: 2.6982
Epoch [7/10], Batch [200/937], Loss D: 0.2223, Loss G: 2.4345
Epoch [7/10], Batch [300/937], Loss D: 0.3087, Loss G: 2.4485
Epoch [7/10], Batch [400/937], Loss D: 0.2545, Loss G: 2.7256
Epoch [7/10], Batch [500/937], Loss D: 0.2279, Loss G: 2.9668
Epoch [7/10], Batch [600/937], Loss D: 0.2359, Loss G: 2.9180
Epoch [7/10], Batch [700/937], Loss D: 0.2313, Loss G: 2.8279
Epoch [7/10], Batch [800/937], Loss D: 0.1824, Loss G: 5.1954
Epoch [7/10], Batch [900/937], Loss D: 0.2634, Loss G: 3.4008
Epoch [8/10], Batch [0/937], Loss D: 0.2653, Loss G: 2.3400
Epoch [8/10], Batch [100/937], Loss D: 0.2266, Loss G: 2.8127
Epoch [8/10], Batch [200/937], Loss D: 0.2408, Loss G: 2.6616
Epoch [8/10], Batch [300/937], Loss D: 0.3541, Loss G: 1.7452
Epoch [8/10], Batch [400/937], Loss D: 0.2817, Loss G: 3.2503
Epoch [8/10], Batch [500/937], Loss D: 0.2609, Loss G: 2.7807
Epoch [8/10], Batch [600/937], Loss D: 0.1834, Loss G: 3.2357
Epoch [8/10], Batch [700/937], Loss D: 0.3784, Loss G: 2.9750
Epoch [8/10], Batch [800/937], Loss D: 0.2408, Loss G: 2.8422
Epoch [8/10], Batch [900/937], Loss D: 0.2076, Loss G: 2.6136
Epoch [9/10], Batch [0/937], Loss D: 0.3452, Loss G: 2.8782
Epoch [9/10], Batch [100/937], Loss D: 0.4407, Loss G: 3.3718
Epoch [9/10], Batch [200/937], Loss D: 0.2751, Loss G: 2.9589
Epoch [9/10], Batch [300/937], Loss D: 0.3030, Loss G: 3.0577
Epoch [9/10], Batch [400/937], Loss D: 0.1908, Loss G: 3.5729
Epoch [9/10], Batch [500/937], Loss D: 0.2850, Loss G: 1.8752
Epoch [9/10], Batch [600/937], Loss D: 0.3251, Loss G: 3.0342
Epoch [9/10], Batch [700/937], Loss D: 0.5650, Loss G: 2.9326
Epoch [9/10], Batch [800/937], Loss D: 0.3391, Loss G: 2.0852
Epoch [9/10], Batch [900/937], Loss D: 0.1692, Loss G: 2.6809

[Figure: grid of generated digit samples]
Results after training for 100 epochs:
[Figure: grid of generated digit samples after 100 epochs]
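
If you also want to save the generated grid to disk rather than only display it, torchvision's save_image helper can be appended to the script above (a small sketch; the filename is arbitrary):

from torchvision.utils import save_image

with torch.no_grad():
    samples = generator(torch.randn(16, input_size, device=device))
# normalize=True rescales the Tanh output from [-1, 1] to [0, 1] before writing the PNG
save_image(samples, 'gan_samples.png', nrow=4, normalize=True)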
