GAN 生成式对抗网络介绍

当谈到生成式对抗网络(GANs)时,一个常见的例子是图像生成。让我们以生成手写数字图像为例来详细说明GANs的工作原理。

  1. 定义网络结构

    • 生成器(Generator):接收一个随机噪声向量作为输入,通常是一个低维的随机向量(例如100维),然后通过一系列反卷积层(deconvolutional layers)将这个随机噪声向量映射为一张与真实手写数字图像相似的图像。
    • 判别器(Discriminator):接收一张图像作为输入,然后通过一系列卷积层(convolutional layers)将其映射为一个概率,表示这张图像是真实手写数字图像的概率。
  2. 初始化网络参数:生成器和判别器的权重参数需要进行初始化,可以使用随机初始化的方式。

  3. 定义损失函数

    • 生成器的损失函数:生成器的目标是生成逼真的手写数字图像,因此其损失函数通常是生成的图像被判别器判别为真实图像的概率的负对数似然。
    • 判别器的损失函数:判别器的目标是正确地区分真实手写数字图像和生成器生成的合成手写数字图像,因此其损失函数通常是真实图像被判别为真实图像的概率与生成图像被判别为真实图像的概率之间的差异。
  4. 训练过程

    • 生成器首先生成一批合成手写数字图像,然后这批合成图像与来自真实手写数字图像数据集的一批真实图像一起输入到判别器中。
    • 判别器对这些图像进行判断,并计算生成器和判别器的损失函数。
    • 根据损失函数的值更新生成器和判别器的参数。
    • 重复上述步骤,直到生成器生成的手写数字图像足够逼真或者损失函数收敛。
  5. 评估生成器:训练完成后,可以使用生成器生成手写数字图像样本,并通过人工或者其他评价指标来评估生成的图像的质量。

  6. 应用生成器:训练好的生成器可以应用于各种任务,例如生成手写数字图像的样本、图像修复等。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# 定义生成器网络
class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, output_size),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.fc(x)

# 定义判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.fc(x)

# 定义训练参数
batch_size = 64
input_size = 100
output_size = 784  # 对应28x28的手写数字图像

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 创建生成器和判别器实例
G = Generator(input_size, output_size)
D = Discriminator(output_size)

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)

# 定义训练循环
num_epochs = 100
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        
        # 训练判别器
        D.zero_grad()
        z = torch.randn(batch_size, input_size)
        fake_images = G(z)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        real_outputs = D(real_images.view(batch_size, -1))
        fake_outputs = D(fake_images.detach().view(batch_size, -1))
        d_loss = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels)
        d_loss.backward()
        optimizer_D.step()
        
        # 训练生成器
        G.zero_grad()
        fake_outputs = D(fake_images.view(batch_size, -1))
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()
        
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch Step [{batch_idx}/{len(train_loader)}], \
                  D_Loss: {d_loss.item():.4f}, G_Loss: {g_loss.item():.4f}")

# 生成手写数字图像样本
num_samples = 10
z = torch.randn(num_samples, input_size)
fake_images = G(z)

这段代码是一个基本的生成对抗网络(GANs)的实现示例,用于生成手写数字图像。下面是对代码的详细解释:

  1. 导入库:

    • import torch: 导入PyTorch库,用于构建深度学习模型和进行张量计算。
    • import torch.nn as nn: 导入PyTorch的神经网络模块,用于构建神经网络层。
    • import torch.optim as optim: 导入PyTorch的优化器模块,用于定义优化器。
    • from torchvision import datasets, transforms: 导入torchvision库,其中包含常用的数据集和数据转换操作。
    • from torch.utils.data import DataLoader: 导入PyTorch的数据加载模块,用于加载训练数据。
    • import numpy as np: 导入NumPy库,用于数值计算。
  2. 定义生成器(Generator)和判别器(Discriminator)的网络结构:

    • Generator类:定义了生成器的网络结构,包括一个全连接层序列,接受输入大小为input_size,输出大小为output_size
    • Discriminator类:定义了判别器的网络结构,也包括一个全连接层序列,接受输入大小为input_size,输出大小为1(用于二分类)。
  3. 定义训练参数和加载数据集:

    • batch_size:每个训练批次的样本数量。
    • input_size:生成器网络输入的随机噪声向量大小。
    • output_size:生成器网络输出的图像大小(在这个例子中是28x28=784)。
    • 使用transforms.Compose定义了数据预处理的操作,包括将图像转换为张量和归一化。
    • 加载MNIST数据集,并定义了训练数据集的数据加载器。
  4. 创建生成器和判别器实例:

    • 实例化了GeneratorDiscriminator类,分别得到生成器G和判别器D的对象。
  5. 定义损失函数和优化器:

    • 使用二元交叉熵损失函数nn.BCELoss(),用于计算生成器和判别器的损失。
    • 分别为生成器和判别器定义了Adam优化器,并传入相应的参数。
  6. 定义训练循环:

    • num_epochs定义了训练的总轮数。
    • 使用双重循环进行训练,外层循环遍历每个epoch,内层循环遍历每个批次的数据。
    • 在每个批次中,首先训练判别器:随机生成噪声数据z,通过生成器生成假图像,然后计算判别器的损失,并更新判别器的参数。
    • 然后训练生成器:再次生成假图像,通过判别器判断,并计算生成器的损失,并更新生成器的参数。
    • 在每个epoch的指定步骤(batch_idx % 100 == 0)打印判别器和生成器的损失值。
  7. 生成手写数字图像样本:

    • 使用生成器G生成指定数量的手写数字图像样本,通过给定的随机噪声z作为输入。

这样,整个代码就是一个简单的GANs实现示例,用于生成手写数字图像。

  • 15
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值