深度学习Day-30:CGAN入门丨生成手势图像丨可控制生成

  🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
 🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  1. 结合代码进一步了解CGAN
  2. 学习如何运用生成好的生成器生成指定图像

一、 基础配置

  • 语言环境:Python3.8
  • 编译器选择:Pycharm
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

二、 前期准备 

1. 导入第三方库

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

得到如下输出: 

cuda

 2. 导入数据

batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="GAN-3-data", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

3. 数据可视化

运行下述代码:

def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

输出图像为:

4. 定义超参数 

运行下述代码:

latent_dim = 100
n_classes = 3
embedding_dim = 100

5. 构建模型

5.1.初始化权重

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

5.2.定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

输出为:

Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

 5.3.定义鉴别器

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img, label = inputs

        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

输出为:

Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================

三、 训练模型 

1. 定义训练参数

adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2. 定义优化器

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3. 训练模型

num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)

        labels = labels.unsqueeze(1).long()

        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))

        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出为:

Epoch: [1/100]: D_loss: 0.285, G_loss: 2.018
Epoch: [2/100]: D_loss: 0.331, G_loss: 2.298
Epoch: [3/100]: D_loss: 0.403, G_loss: 1.715
Epoch: [4/100]: D_loss: 0.467, G_loss: 1.416
Epoch: [5/100]: D_loss: 0.490, G_loss: 1.618
Epoch: [6/100]: D_loss: 0.490, G_loss: 1.585
Epoch: [7/100]: D_loss: 0.379, G_loss: 1.674
Epoch: [8/100]: D_loss: 0.443, G_loss: 1.889
Epoch: [9/100]: D_loss: 0.541, G_loss: 2.067
Epoch: [10/100]: D_loss: 0.565, G_loss: 1.751
Epoch: [11/100]: D_loss: 0.528, G_loss: 1.495
Epoch: [12/100]: D_loss: 0.555, G_loss: 1.461
Epoch: [13/100]: D_loss: 0.569, G_loss: 1.490
Epoch: [14/100]: D_loss: 0.531, G_loss: 1.498
Epoch: [15/100]: D_loss: 0.504, G_loss: 1.532
Epoch: [16/100]: D_loss: 0.487, G_loss: 1.612
Epoch: [17/100]: D_loss: 0.457, G_loss: 1.776
Epoch: [18/100]: D_loss: 0.462, G_loss: 1.767
Epoch: [19/100]: D_loss: 0.437, G_loss: 1.946
Epoch: [20/100]: D_loss: 0.446, G_loss: 1.848
Epoch: [21/100]: D_loss: 0.463, G_loss: 1.718
Epoch: [22/100]: D_loss: 0.473, G_loss: 1.748
Epoch: [23/100]: D_loss: 0.503, G_loss: 1.579
Epoch: [24/100]: D_loss: 0.482, G_loss: 1.410
Epoch: [25/100]: D_loss: 0.489, G_loss: 1.440
Epoch: [26/100]: D_loss: 0.494, G_loss: 1.425
Epoch: [27/100]: D_loss: 0.510, G_loss: 1.398
Epoch: [28/100]: D_loss: 0.475, G_loss: 1.410
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.459
Epoch: [30/100]: D_loss: 0.473, G_loss: 1.489
Epoch: [31/100]: D_loss: 0.462, G_loss: 1.484
Epoch: [32/100]: D_loss: 0.448, G_loss: 1.520
Epoch: [33/100]: D_loss: 0.457, G_loss: 1.548
Epoch: [34/100]: D_loss: 0.418, G_loss: 1.558
Epoch: [35/100]: D_loss: 0.433, G_loss: 1.667
Epoch: [36/100]: D_loss: 0.402, G_loss: 1.665
Epoch: [37/100]: D_loss: 0.401, G_loss: 1.709
Epoch: [38/100]: D_loss: 0.425, G_loss: 1.841
Epoch: [39/100]: D_loss: 0.399, G_loss: 1.711
Epoch: [40/100]: D_loss: 0.429, G_loss: 1.873
Epoch: [41/100]: D_loss: 0.374, G_loss: 1.857
Epoch: [42/100]: D_loss: 0.382, G_loss: 1.869
Epoch: [43/100]: D_loss: 0.431, G_loss: 1.935
Epoch: [44/100]: D_loss: 0.355, G_loss: 1.871
Epoch: [45/100]: D_loss: 0.363, G_loss: 1.875
Epoch: [46/100]: D_loss: 0.485, G_loss: 2.011
Epoch: [47/100]: D_loss: 0.391, G_loss: 1.994
Epoch: [48/100]: D_loss: 0.331, G_loss: 1.924
Epoch: [49/100]: D_loss: 0.317, G_loss: 1.930
Epoch: [50/100]: D_loss: 0.353, G_loss: 2.035
Epoch: [51/100]: D_loss: 0.334, G_loss: 2.072
Epoch: [52/100]: D_loss: 0.387, G_loss: 2.092
Epoch: [53/100]: D_loss: 0.380, G_loss: 2.139
Epoch: [54/100]: D_loss: 0.302, G_loss: 2.077
Epoch: [55/100]: D_loss: 0.311, G_loss: 2.055
Epoch: [56/100]: D_loss: 0.326, G_loss: 2.169
Epoch: [57/100]: D_loss: 0.309, G_loss: 2.239
Epoch: [58/100]: D_loss: 0.323, G_loss: 2.207
Epoch: [59/100]: D_loss: 0.285, G_loss: 2.239
Epoch: [60/100]: D_loss: 0.306, G_loss: 2.304
Epoch: [61/100]: D_loss: 0.287, G_loss: 2.254
Epoch: [62/100]: D_loss: 0.295, G_loss: 2.406
Epoch: [63/100]: D_loss: 0.305, G_loss: 2.499
Epoch: [64/100]: D_loss: 0.298, G_loss: 2.462
Epoch: [65/100]: D_loss: 0.255, G_loss: 2.418
Epoch: [66/100]: D_loss: 0.480, G_loss: 2.714
Epoch: [67/100]: D_loss: 0.265, G_loss: 2.379
Epoch: [68/100]: D_loss: 0.256, G_loss: 2.453
Epoch: [69/100]: D_loss: 0.252, G_loss: 2.465
Epoch: [70/100]: D_loss: 0.240, G_loss: 2.600
Epoch: [71/100]: D_loss: 0.250, G_loss: 2.516
Epoch: [72/100]: D_loss: 0.228, G_loss: 2.534
Epoch: [73/100]: D_loss: 0.249, G_loss: 2.566
Epoch: [74/100]: D_loss: 0.385, G_loss: 2.915
Epoch: [75/100]: D_loss: 0.232, G_loss: 2.566
Epoch: [76/100]: D_loss: 0.335, G_loss: 2.776
Epoch: [77/100]: D_loss: 0.243, G_loss: 2.703
Epoch: [78/100]: D_loss: 0.232, G_loss: 2.650
Epoch: [79/100]: D_loss: 0.216, G_loss: 2.736
Epoch: [80/100]: D_loss: 0.219, G_loss: 2.725
Epoch: [81/100]: D_loss: 0.272, G_loss: 2.869
Epoch: [82/100]: D_loss: 0.218, G_loss: 2.839
Epoch: [83/100]: D_loss: 0.219, G_loss: 2.836
Epoch: [84/100]: D_loss: 0.233, G_loss: 2.948
Epoch: [85/100]: D_loss: 0.209, G_loss: 2.952
Epoch: [86/100]: D_loss: 0.251, G_loss: 3.052
Epoch: [87/100]: D_loss: 0.198, G_loss: 2.905
Epoch: [88/100]: D_loss: 0.193, G_loss: 3.054
Epoch: [89/100]: D_loss: 0.215, G_loss: 2.995
Epoch: [90/100]: D_loss: 0.193, G_loss: 3.081
Epoch: [91/100]: D_loss: 0.446, G_loss: 3.269
Epoch: [92/100]: D_loss: 0.227, G_loss: 2.871
Epoch: [93/100]: D_loss: 0.191, G_loss: 3.008
Epoch: [94/100]: D_loss: 0.200, G_loss: 3.066
Epoch: [95/100]: D_loss: 0.200, G_loss: 3.142
Epoch: [96/100]: D_loss: 0.186, G_loss: 3.113
Epoch: [97/100]: D_loss: 0.207, G_loss: 3.159
Epoch: [98/100]: D_loss: 0.219, G_loss: 3.213
Epoch: [99/100]: D_loss: 0.177, G_loss: 3.205
Epoch: [100/100]: D_loss: 0.184, G_loss: 3.258

4. 可视化

4.1.LOSS图

G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Descriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出图像为:

 

4.2.生成指定图像

from numpy.random import randn

generator.load_state_dict(torch.load("./training_weights/generator_epoch_100.pth"), strict = False)
generator.eval()

interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100


plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

输出图像为:

四、理论基础

        CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

        CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

        CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. 有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. 联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. 可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. 使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

        相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. 条件信息的输入:CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. 训练稳定性:传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. 损失函数:虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. 网络结构:在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

        CGAN网络结构如下图所示:                由上图的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用

        综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它提供了更明确的训练目标和更好的生成效果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值