Deep Learning Day 29: Introduction to CGAN | Generating Hand-Gesture Images

  🍨 This article is a study-log post from [🔗365天深度学习训练营]
 🍖 Original author: [K同学啊 | 接辅导、项目定制]

Objectives:

  1. Understand the basic principles of Conditional Generative Adversarial Networks (CGAN)
  2. Understand how CGAN implements conditional control
  3. Study the CGAN code in this post and get it running end to end
  4. Generate images of a specified gesture (next week)

I. Basic Configuration

  • Language environment: Python 3.8
  • IDE: PyCharm
  • Deep learning environment:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

II. Preliminaries

1. Import third-party libraries

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)            # generated-image samples
os.makedirs('./training_weights', exist_ok=True)  # model checkpoints (used in Part III)

2. Set the random seed

Run the following code:

torch.manual_seed(1)  # fix the random seed for reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

3. Load the data

train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    # Normalize to [-1, 1] to match the generator's Tanh output range.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_dataset = datasets.ImageFolder(root="H:/G3周数据集/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)
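
ImageFolder assigns class indices from the subfolder names in sorted order. A quick optional check (my addition, not in the original post) of the resulting label mapping:

# Optional sanity check of the class-to-index mapping.
print(train_dataset.class_to_idx)  # e.g. {'paper': 0, 'rock': 1, 'scissors': 2}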

4. Data visualization

Run the following code:

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))
    plt.show()

def show_batch(dl):
    # Show a single batch, then stop.
    for images, _ in dl:
        show_images(images)
        break

show_batch(train_loader)

The output is a grid of one batch of rock-paper-scissors training images.

5. Define hyperparameters

Run the following code:

image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))  # 3 * 128 * 128 = 49152
latent_dim = 100     # length of the noise vector z
n_classes = 3        # rock / paper / scissors
embedding_dim = 100  # size of the label embedding

6. Build the models

6.1. Initialize weights

def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv layers,
    # N(1, 0.02) for BatchNorm scale with zero bias.
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

6.2. Define the generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Embed the class label and project it to a 4x4 single-channel map.
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        # Project the noise vector to a 4x4x512 feature map.
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Input: 512 latent channels + 1 label channel = 513 channels.
        # Five stride-2 transposed convolutions upsample 4x4 to 128x128.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # Condition by concatenating along the channel dimension.
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

The output is:

Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

And:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================
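
As a quick sanity check (my addition, not in the original post), we can push a small dummy batch through the generator and confirm that the output shape matches image_shape:

# Smoke test with 4 random noise vectors and random class labels.
z = torch.randn(4, latent_dim, device=device)
y = torch.randint(0, n_classes, (4, 1), device=device)
with torch.no_grad():
    fake = generator((z, y))
print(fake.shape)  # expected: torch.Size([4, 3, 128, 128])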

6.3. Define the discriminator


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Embed the class label and project it to a full-size 3x128x128 map
        # so it can be concatenated with the input image channel-wise.
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        # Input: 3 image channels + 3 label channels = 6 channels.
        # The strided convolutions reduce 128x128 to 3x3 (512*3*3 = 4608).
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()  # probability that the (image, label) pair is real
        )

    def forward(self, inputs):
        img, label = inputs

        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        # Condition by concatenating along the channel dimension.
        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

The output is:

Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)

And:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================
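
The same kind of optional smoke test (my addition) works for the discriminator; each (image, label) pair should map to a single probability:

# Smoke test with 4 random "images" and random class labels.
x = torch.randn(4, 3, 128, 128, device=device)
y = torch.randint(0, n_classes, (4, 1), device=device)
with torch.no_grad():
    score = discriminator((x, y))
print(score.shape)  # expected: torch.Size([4, 1])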

III. Model Training

1. Define the loss functions

# Both networks are trained with binary cross-entropy against real/fake targets.
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
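
As a tiny illustration (hypothetical numbers, not part of the training script) of how BCELoss scores predictions — a confident correct prediction contributes a small loss, a confident wrong one a large loss:

demo_pred = torch.tensor([[0.9], [0.1]])  # predicted probabilities of being real
demo_true = torch.tensor([[1.0], [1.0]])  # both samples are actually real
print(adversarial_loss(demo_pred, demo_true))  # (-ln 0.9 - ln 0.1) / 2 ≈ 1.204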

2. Define the optimizers

learning_rate = 0.0002

# beta1 = 0.5 follows the DCGAN paper's recommendation for stable GAN training.
G_optimizer = optim.Adam(generator.parameters(),     lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))


3. Train the model

num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---------------- Train the discriminator ----------------
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()  # shape (B, 1) for the Embedding layer

        real_target = torch.ones(real_images.size(0), 1).to(device)
        fake_target = torch.zeros(real_images.size(0), 1).to(device)

        # Loss on real images.
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        # Loss on generated images (detach so no gradient flows into G here).
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        generated_image = generator((noise_vector, labels))
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())

        D_total_loss.backward()
        D_optimizer.step()

        # ---------------- Train the generator ----------------
        G_optimizer.zero_grad()
        # The generator tries to make D classify its output as real.
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        epoch, num_epochs, np.mean(D_loss_list), np.mean(G_loss_list)))

    D_loss_plot.append(np.mean(D_loss_list))
    G_loss_plot.append(np.mean(G_loss_list))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d.png' % epoch, nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % epoch)
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % epoch)

The output is:

Epoch: [1/100]: D_loss: 0.281, G_loss: 1.986
Epoch: [2/100]: D_loss: 0.244, G_loss: 2.649
Epoch: [3/100]: D_loss: 0.355, G_loss: 2.056
Epoch: [4/100]: D_loss: 0.275, G_loss: 2.004
Epoch: [5/100]: D_loss: 0.476, G_loss: 2.125
Epoch: [6/100]: D_loss: 0.316, G_loss: 2.227
Epoch: [7/100]: D_loss: 0.442, G_loss: 1.857
Epoch: [8/100]: D_loss: 0.391, G_loss: 1.779
Epoch: [9/100]: D_loss: 0.429, G_loss: 1.886
Epoch: [10/100]: D_loss: 0.362, G_loss: 2.152
Epoch: [11/100]: D_loss: 0.488, G_loss: 1.869
Epoch: [12/100]: D_loss: 0.498, G_loss: 1.991
Epoch: [13/100]: D_loss: 0.539, G_loss: 1.923
Epoch: [14/100]: D_loss: 0.543, G_loss: 1.552
Epoch: [15/100]: D_loss: 0.558, G_loss: 1.455
Epoch: [16/100]: D_loss: 0.532, G_loss: 1.418
Epoch: [17/100]: D_loss: 0.531, G_loss: 1.421
Epoch: [18/100]: D_loss: 0.502, G_loss: 1.496
Epoch: [19/100]: D_loss: 0.523, G_loss: 1.656
Epoch: [20/100]: D_loss: 0.431, G_loss: 1.458
Epoch: [21/100]: D_loss: 0.494, G_loss: 1.750
Epoch: [22/100]: D_loss: 0.447, G_loss: 1.683
Epoch: [23/100]: D_loss: 0.443, G_loss: 1.873
Epoch: [24/100]: D_loss: 0.461, G_loss: 1.924
Epoch: [25/100]: D_loss: 0.402, G_loss: 1.817
Epoch: [26/100]: D_loss: 0.462, G_loss: 1.758
Epoch: [27/100]: D_loss: 0.495, G_loss: 1.570
Epoch: [28/100]: D_loss: 0.501, G_loss: 1.440
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.382
Epoch: [30/100]: D_loss: 0.485, G_loss: 1.431
Epoch: [31/100]: D_loss: 0.478, G_loss: 1.402
Epoch: [32/100]: D_loss: 0.464, G_loss: 1.427
Epoch: [33/100]: D_loss: 0.442, G_loss: 1.559
Epoch: [34/100]: D_loss: 0.470, G_loss: 1.655
Epoch: [35/100]: D_loss: 0.418, G_loss: 1.629
Epoch: [36/100]: D_loss: 0.411, G_loss: 1.645
Epoch: [37/100]: D_loss: 0.426, G_loss: 1.654
Epoch: [38/100]: D_loss: 0.418, G_loss: 1.786
Epoch: [39/100]: D_loss: 0.381, G_loss: 1.830
Epoch: [40/100]: D_loss: 0.412, G_loss: 1.814
Epoch: [41/100]: D_loss: 0.371, G_loss: 1.734
Epoch: [42/100]: D_loss: 0.427, G_loss: 1.891
Epoch: [43/100]: D_loss: 0.378, G_loss: 1.751
Epoch: [44/100]: D_loss: 0.419, G_loss: 1.915
Epoch: [45/100]: D_loss: 0.374, G_loss: 1.813
Epoch: [46/100]: D_loss: 0.397, G_loss: 1.936
Epoch: [47/100]: D_loss: 0.367, G_loss: 1.856
Epoch: [48/100]: D_loss: 0.417, G_loss: 1.964
Epoch: [49/100]: D_loss: 0.383, G_loss: 1.904
Epoch: [50/100]: D_loss: 0.357, G_loss: 1.889
Epoch: [51/100]: D_loss: 0.458, G_loss: 2.058
Epoch: [52/100]: D_loss: 0.334, G_loss: 1.876
Epoch: [53/100]: D_loss: 0.328, G_loss: 1.872
Epoch: [54/100]: D_loss: 0.418, G_loss: 2.185
Epoch: [55/100]: D_loss: 0.349, G_loss: 1.978
Epoch: [56/100]: D_loss: 0.337, G_loss: 2.011
Epoch: [57/100]: D_loss: 0.334, G_loss: 2.069
Epoch: [58/100]: D_loss: 0.337, G_loss: 2.061
Epoch: [59/100]: D_loss: 0.352, G_loss: 2.201
Epoch: [60/100]: D_loss: 0.328, G_loss: 2.206
Epoch: [61/100]: D_loss: 0.333, G_loss: 2.176
Epoch: [62/100]: D_loss: 0.308, G_loss: 2.132
Epoch: [63/100]: D_loss: 0.454, G_loss: 2.441
Epoch: [64/100]: D_loss: 0.284, G_loss: 2.191
Epoch: [65/100]: D_loss: 0.273, G_loss: 2.232
Epoch: [66/100]: D_loss: 0.330, G_loss: 2.367
Epoch: [67/100]: D_loss: 0.285, G_loss: 2.294
Epoch: [68/100]: D_loss: 0.281, G_loss: 2.353
Epoch: [69/100]: D_loss: 0.278, G_loss: 2.357
Epoch: [70/100]: D_loss: 0.294, G_loss: 2.434
Epoch: [71/100]: D_loss: 0.311, G_loss: 2.411
Epoch: [72/100]: D_loss: 0.266, G_loss: 2.426
Epoch: [73/100]: D_loss: 0.296, G_loss: 2.532
Epoch: [74/100]: D_loss: 0.262, G_loss: 2.449
Epoch: [75/100]: D_loss: 0.263, G_loss: 2.516
Epoch: [76/100]: D_loss: 0.310, G_loss: 2.621
Epoch: [77/100]: D_loss: 0.258, G_loss: 2.595
Epoch: [78/100]: D_loss: 0.258, G_loss: 2.570
Epoch: [79/100]: D_loss: 0.391, G_loss: 2.698
Epoch: [80/100]: D_loss: 0.269, G_loss: 2.590
Epoch: [81/100]: D_loss: 0.238, G_loss: 2.629
Epoch: [82/100]: D_loss: 0.264, G_loss: 2.649
Epoch: [83/100]: D_loss: 0.224, G_loss: 2.706
Epoch: [84/100]: D_loss: 0.226, G_loss: 2.667
Epoch: [85/100]: D_loss: 0.250, G_loss: 2.836
Epoch: [86/100]: D_loss: 0.244, G_loss: 2.790
Epoch: [87/100]: D_loss: 0.284, G_loss: 2.830
Epoch: [88/100]: D_loss: 0.223, G_loss: 2.864
Epoch: [89/100]: D_loss: 0.233, G_loss: 2.853
Epoch: [90/100]: D_loss: 0.264, G_loss: 2.869
Epoch: [91/100]: D_loss: 0.210, G_loss: 2.863
Epoch: [92/100]: D_loss: 0.235, G_loss: 2.965
Epoch: [93/100]: D_loss: 0.206, G_loss: 2.944
Epoch: [94/100]: D_loss: 0.242, G_loss: 2.956
Epoch: [95/100]: D_loss: 0.232, G_loss: 2.955
Epoch: [96/100]: D_loss: 0.682, G_loss: 3.462
Epoch: [97/100]: D_loss: 0.207, G_loss: 2.843
Epoch: [98/100]: D_loss: 0.200, G_loss: 2.819
Epoch: [99/100]: D_loss: 0.197, G_loss: 2.965
Epoch: [100/100]: D_loss: 0.190, G_loss: 2.962
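
The per-epoch averages recorded in D_loss_plot and G_loss_plot can be plotted to inspect the training dynamics (a small addition, not in the original post):

plt.figure(figsize=(10, 5))
plt.plot(D_loss_plot, label='D_loss')
plt.plot(G_loss_plot, label='G_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()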

4. Model analysis

4.1. Load the model

generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth', map_location=device), strict=False)
generator.eval()
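
With the weights loaded, requirement 4 (generating a specified gesture) comes down to fixing the label input. A minimal sketch (my addition; class index 0 is used here, and the actual index-to-gesture mapping comes from train_dataset.class_to_idx):

with torch.no_grad():
    z = torch.randn(8, latent_dim, device=device)
    y = torch.full((8, 1), 0, dtype=torch.long, device=device)  # all samples get class 0
    imgs = generator((z, y))
save_image(imgs, './images/class0_samples.png', nrow=4, normalize=True)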

4.2. Analyze the model (latent-space interpolation)

from numpy import asarray
from numpy.random import randn
from numpy import linspace
from matplotlib import gridspec

def generate_latent_points(latent_dim, n_samples):
    # Sample n_samples points from the standard normal latent space.
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

def interpolate_points(p1, p2, n_steps=10):
    # Linear interpolation between two latent points in n_steps steps.
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

# Two random latent points, interpolated in 10 steps.
pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

# Generate the same 10 interpolated latents under each of the 3 class labels.
output = None
for label in range(3):
    labels = torch.ones(10) * label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0, 2, 3, 1)  # NCHW -> NHWC for plotting
    pred = predictions.detach().cpu()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output, pred))

print(output.shape)

# Plot a 3x10 grid: one row per class, ten interpolation steps per row.
nrow = 3
ncol = 10

fig = plt.figure(figsize=(15, 4))
gs = gridspec.GridSpec(nrow, ncol)

k = 0
for i in range(nrow):
    for j in range(ncol):
        pred = (output[k, :, :, :] + 1) * 127.5  # map [-1, 1] back to [0, 255]
        pred = np.array(pred)
        ax = plt.subplot(gs[i, j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1

plt.show()

The output shapes are:

torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1])
(30, 128, 128, 3)

The output image is a 3x10 grid: one row per gesture class, ten interpolation steps per row.

IV. Theoretical Background

        The principle of CGAN (Conditional Generative Adversarial Network) is to extend the original GAN by supplying additional conditioning information to both the generator and the discriminator.

        CGAN feeds conditioning information (such as class labels or other auxiliary data) into the inputs of both the generator and the discriminator. The generator can then produce data of a specific type according to the condition, while the discriminator judges whether real and generated samples match that condition. This gives the generator a clear direction when generating data, which improves both the quality and the relevance of the results.

        CGAN's characteristics include supervised learning, a joint latent representation, controllability, and the use of convolutional structures. In detail:

  1. Supervised learning: by using the extra information, CGAN turns the originally unsupervised GAN into a supervised learning setup, which makes training more goal-directed and the generated results more predictable.
  2. Joint latent representation: in the generative model, the noise input and the conditioning information together form a joint latent representation, which helps generate more diverse data with specific attributes.
  3. Controllability: a key feature of CGAN is the improved controllability of the generation process; the conditioning information can be adjusted to steer the model toward generating a specific type of data.
  4. Convolutional structures: CGAN can use convolutional neural networks internally, which is especially effective in image-related tasks because convolutions capture local features and improve the handling of detail.

        Compared with the traditional GAN, the main differences concern the conditional input, training stability, the loss function, and the network structure. In detail:

  1. Conditional input: CGAN introduces a condition variable so that both the generator and the discriminator receive extra information to guide training, which the traditional GAN lacks.
  2. Training stability: the traditional GAN is prone to mode collapse during training, whereas the additional conditioning information in CGAN can improve training stability and the diversity of the generated data.
  3. Loss function: CGAN keeps the adversarial loss of the traditional GAN, but the added conditioning information makes the loss computation more structured and targeted (see the objective after this list).
  4. Network structure: in practice CGAN can use deeper and more complex architectures, such as convolutional neural networks, which helps handle more complex data types like high-resolution images.
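
        For reference, the objective from the original CGAN paper (Mirza & Osindero, 2014) is the standard GAN minimax game with both terms conditioned on the label y:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]$$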

        The CGAN network structure is shown in the figure below (figure omitted here). As the structure shows, the condition y is introduced as an extra input to the adversarial networks: in the generator it is merged with the noise z to form the joint hidden representation, while in the discriminator D it is merged with the original data x as the input to the discriminant function. This modification has since proved very effective in many lines of research and has provided useful guidance for subsequent work.

        In summary, the core of CGAN is that it introduces conditioning information to strengthen the model's generative capability and controllability. Compared with the traditional GAN, it provides a clearer training objective and better generation results.
