Conditional GAN Implementation (PyTorch)


Paper:
https://arxiv.org/abs/1411.1784

1. Background

  1. A traditional GAN can generate images, but it cannot control which class of image is generated. For example, when generating handwritten digits, both GAN and DCGAN can produce all ten digits 0-9, but the user cannot specify which digit to generate;
  2. GAN and DCGAN suffer from mode collapse.

2. Main Idea

  • A GAN consists of two networks, a generator and a discriminator; its objective function is

\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log (1 - D(G(z)))]
Compared with a traditional GAN, CGAN differs only in adding the label as an extra input to both networks; its objective is
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_{z}(z)}[\log (1 - D(G(z \mid y)))]

  • Architecture diagram (figure omitted): the generator and the discriminator each receive the label y as an additional input.
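
In practice the expectations are replaced by minibatch averages, and the minimax objective is implemented with two binary cross-entropy losses using the non-saturating generator loss; this is the form the training code in Section 4 uses (that code additionally halves the discriminator loss, which only rescales its gradient):

\mathcal{L}_D = -\frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x^{(i)} \mid y^{(i)}) + \log \left(1 - D(G(z^{(i)} \mid y^{(i)}))\right) \right], \qquad \mathcal{L}_G = -\frac{1}{m} \sum_{i=1}^{m} \log D(G(z^{(i)} \mid y^{(i)}))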

3. Implementation

  1. Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # two 128 x 7 x 7 tensors; concatenated along channels they form 256 x 7 x 7
        self.linear1 = nn.Sequential(
            nn.Linear(100, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.linear2 = nn.Sequential(
            nn.Linear(10, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )

        self.model = nn.Sequential(
            # 256 x 7 x 7 -> 128 x 7 x 7
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 128 x 7 x 7 -> 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 64 x 14 x 14 -> 1 x 28 x 28
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, c):
        x = self.linear1(x)
        x = x.view(-1, 128, 7, 7)
        c = self.linear2(c)
        c = c.view(-1, 128, 7, 7)
        # concatenate along the channel dimension: 256 x 7 x 7
        x = torch.cat([x, c], dim=1)
        return self.model(x)

The generator takes two inputs: a noise vector and a label. Each is mapped by a linear layer to a 128x7x7 tensor, and the two tensors are concatenated along the channel dimension into a 256x7x7 tensor, which then passes through three transposed convolutions to produce a 1x28x28 image (matching the size of MNIST samples).
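
As a quick sanity check of the shapes (a sketch; it assumes the Generator class above and the imports from the full script in Section 4):

gen = Generator()
z = torch.randn(4, 100)          # a batch of 4 noise vectors
c = torch.eye(10)[[0, 1, 2, 3]]  # one-hot labels for digits 0-3
print(gen(z, c).shape)           # torch.Size([4, 1, 28, 28])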

  2. Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input: a 1 x 28 x 28 image plus a 10-dim one-hot condition
        self.linear = nn.Sequential(
            nn.Linear(10, 1 * 28 * 28),
            nn.ReLU()
        )
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 6 * 6, 1),
            nn.Sigmoid()
        )

    def forward(self, x, c):
        c = self.linear(c)
        c = c.view(-1, 1, 28, 28)
        # concatenate along channels: 2 x 28 x 28
        x = torch.cat([x, c], dim=1)
        x = self.model(x)
        x = x.view(-1, 128 * 6 * 6)
        x = self.fc(x)
        return x

The discriminator likewise takes two inputs: an image (either real or generated) and a label. The label is first mapped to a 1x28x28 tensor and concatenated with the image into a 2x28x28 tensor, which then passes through convolution, activation, and dropout layers, and finally through a linear layer that outputs a single probability (real or fake).
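
The same kind of shape check works for the discriminator (again a sketch, assuming the class above):

disc = Discriminator()
img = torch.randn(4, 1, 28, 28)  # a batch of 4 images
c = torch.eye(10)[[0, 1, 2, 3]]  # one-hot labels
print(disc(img, c).shape)        # torch.Size([4, 1]), probabilities in (0, 1)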

  3. Training
  • Training the discriminator

The discriminator should distinguish real images from fake ones as well as possible:
feed real images and their labels into the discriminator and compute the loss between its output and 1;
generate fake images from noise, feed them with the labels into the discriminator, and compute the loss between its output and 0;
backpropagate and update.

  • Training the generator

The generator should make its images as close to real ones as possible, so that the discriminator cannot tell whether an image is real or generated:
feed the generated fake images into the discriminator and compute the loss between its output and 1;
backpropagate and update.

4. Full Code

import torch
from torch import nn, cuda
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def one_hot(x, class_count=10):
    # map integer labels to one-hot vectors, e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    return torch.eye(class_count)[x, :]


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # two 128 x 7 x 7 tensors; concatenated along channels they form 256 x 7 x 7
        self.linear1 = nn.Sequential(
            nn.Linear(100, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.linear2 = nn.Sequential(
            nn.Linear(10, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )

        self.model = nn.Sequential(
            # 256 x 7 x 7 -> 128 x 7 x 7
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 128 x 7 x 7 -> 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 64 x 14 x 14 -> 1 x 28 x 28
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, c):
        x = self.linear1(x)
        x = x.view(-1, 128, 7, 7)
        c = self.linear2(c)
        c = c.view(-1, 128, 7, 7)
        # concatenate along the channel dimension: 256 x 7 x 7
        x = torch.cat([x, c], dim=1)
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input: a 1 x 28 x 28 image plus a 10-dim one-hot condition
        self.linear = nn.Sequential(
            nn.Linear(10, 1 * 28 * 28),
            nn.ReLU()
        )
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 6 * 6, 1),
            nn.Sigmoid()
        )

    def forward(self, x, c):
        c = self.linear(c)
        c = c.view(-1, 1, 28, 28)
        # concatenate along channels: 2 x 28 x 28
        x = torch.cat([x, c], dim=1)
        x = self.model(x)
        x = x.view(-1, 128 * 6 * 6)
        x = self.fc(x)
        return x


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # scale images to [-1, 1] to match the generator's Tanh output
])

dataset = torchvision.datasets.MNIST("./data", train=True,
                                     transform=transform,
                                     download=True,
                                     target_transform=one_hot)

dataloader = data.DataLoader(dataset, batch_size=512, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)
disc = Discriminator().to(device)
loss_fn = torch.nn.BCELoss()
opt_g = torch.optim.RMSprop(gen.parameters(), lr=0.0001)
opt_d = torch.optim.Adam(disc.parameters(), lr=0.0001)
num_epochs = 201

writer_g = SummaryWriter("/root/tf-logs/g")
writer_d = SummaryWriter("/root/tf-logs/d")

noise_seed = torch.randn(16, 100, device=device)
# 16 random integer labels in [0, 9]
label_seed = torch.randint(0, 10, size=(16,))
print(f"label seed: {label_seed}")
print(type(label_seed))
label_seed_onehot = one_hot(label_seed).to(device)
print(f"label_seed: {label_seed}")

for epoch in range(num_epochs):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # number of batches, used to average the epoch losses
    loop = tqdm(dataloader, leave=True, desc=f"Epoch: {epoch}/{num_epochs}")
    for step, (img, label) in enumerate(loop):
        img = img.to(device)
        label = label.to(device)
        size = img.shape[0]
        random_seed = torch.randn(size, 100, device=device)

        # train the discriminator
        opt_d.zero_grad()
        # feed real images into the discriminator
        real_output = disc(img, label)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))
        # generate fake images and feed them (detached) into the discriminator
        gen_img = gen(random_seed, label)
        fake_output = disc(gen_img.detach(), label)
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        opt_d.step()

        # train the generator
        opt_g.zero_grad()
        fake_output = disc(gen_img, label)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))
        g_loss.backward()
        opt_g.step()

        with torch.no_grad():
            D_epoch_loss += d_loss.item()
            G_epoch_loss += g_loss.item()
            loop.set_postfix(G_loss=f"{np.round(G_epoch_loss, 2)}", D_loss=f"{np.round(D_epoch_loss, 2)}")

    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        writer_g.add_scalar("loss", G_epoch_loss, epoch)
        writer_d.add_scalar("loss", D_epoch_loss, epoch)

        if epoch % 20 == 0:
            gen_img = gen(noise_seed, label_seed_onehot)
            # rescale the Tanh output from [-1, 1] to [0, 1] before logging
            writer_g.add_images("gen mnist", (gen_img + 1) / 2, epoch)

torch.save(gen.state_dict(), "./gen.pth")
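
After training, a specific digit can be generated from the saved checkpoint. A minimal sampling sketch (it reuses the Generator class, one_hot, and device from the script above; the filename digit_7.png is arbitrary):

gen = Generator().to(device)
gen.load_state_dict(torch.load("./gen.pth", map_location=device))
gen.eval()  # switch the BatchNorm layers to inference mode

digit = 7  # the digit to generate
z = torch.randn(16, 100, device=device)
labels = one_hot(torch.full((16,), digit, dtype=torch.long)).to(device)
with torch.no_grad():
    imgs = gen(z, labels)  # 16 x 1 x 28 x 28, Tanh output in [-1, 1]
torchvision.utils.save_image((imgs + 1) / 2, f"digit_{digit}.png", nrow=4)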

5. Training Results

(Generated sample grids for the fixed noise and label seeds are written to TensorBoard every 20 epochs via writer_g.add_images; the result images from the original post are omitted.)

6. References

  1. https://arxiv.org/abs/1411.1784
  2. https://blog.csdn.net/qq_41647438/article/details/103007057
  3. https://blog.csdn.net/xjp_xujiping/article/details/102719363
  4. https://zhuanlan.zhihu.com/p/510346635
  5. https://www.jianshu.com/p/39c57e9a6630