Gan 对抗生成网络

两个重要的东西,生成器,判别器。

以图片生成举例。生成器顾名思义就是生成图片,判别器就是判别这个图片是真的还是假的。

结构就出来了,首先要有一个网络来进行图片生成,还要有一个网络对正确图片进行训练判别,还有一个对假的模拟生成图片进行判别。这两个网络最后的结果是一个0-1的值是一个概率。这样不断调整,当这个判断越来越准确这个图片生成的就会越来越逼真,这样这几个网络的结果相互影响最终这个识别图片会很精准,图片生成也会很逼真。

看源码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(input_size, hidden_size * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size, output_size, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# 定义判别器模型
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(input_size, hidden_size, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size, hidden_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 2, hidden_size * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 4, hidden_size * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

# 超参数设置
latent_size = 100
hidden_size = 128
image_size = 64
channels = 1
batch_size = 128
num_epochs = 200
learning_rate = 0.0002
beta1 = 0.5

# 图像预处理
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 实例化生成器和判别器模型
G = Generator(latent_size, hidden_size, channels).to(device)
D = Discriminator(channels, hidden_size).to(device)

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# 开始训练
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        # 标签定义
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 训练判别器
        D.zero_grad()
        images = images.to(device)
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()

        z = torch.randn(batch_size, latent_size, 1, 1).to(device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizer_D.step()

        # 训练生成器
        G.zero_grad()
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], '
                  f'D Loss: {d_loss_real.item() + d_loss_fake.item():.4f}, G Loss: {g_loss.item():.4f}')

    # 保存生成的图像
    if (epoch + 1) % 10 == 0:
        save_image(fake_images, f'./samples/fake_images-{epoch+1}.png')

# 生成并保存最终的图像
z = torch.randn(batch_size, latent_size, 1, 1).to(device)
fake_images = G(z)
save_image(fake_images, './samples/fake_images_final.png')

注意这个D.zero_grad() images = images.to(device) outputs = D(images) d_loss_real = criterion(outputs, real_labels) d_loss_real.backward() z = torch.randn(batch_size, latent_size, 1, 1).to(device) fake_images = G(z) outputs = D(fake_images.detach()) d_loss_fake = criterion(outputs, fake_labels) d_loss_fake.backward() optimizer_D.step()

这是训练判别器的代码,这一部分是训练了两个数据,一个是真实的图片一个是高斯分布模拟出来的图片。因为这个模型要有对真实图片和假的图片分辨能力。

nn.ConvTranspose2d这个玩意是反卷积层(转置卷积/上采样卷积),把输入的小的特征图转化为大的特征图,说白了就是根据特征图生成含有这些特征的模拟图片,剩下的代码就很简单不做解释

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值