第G4周:CGAN|生成手势图像|可控制生成

一、前置知识

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. 有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. 联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. 可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. 使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. 条件信息的输入:CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. 训练稳定性:传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. 损失函数:虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. 网络结构:在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它
提供了更明确的训练目标和更好的生成效果。

二、准备工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

输出
device(type=‘cuda’)

1.导入数据

batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=6)

2.数据可视化

def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

输出
在这里插入图片描述

三、构建模型

```python
latent_dim = 100
n_classes = 3
embedding_dim = 100

1.初始化权重

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

2.构建生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim), 
            nn.Linear(embedding_dim, 16)            
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  
            nn.LeakyReLU(0.2, inplace=True)  
        )

        self.model = nn.Sequential( 
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  
            nn.ReLU(True),            
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),     
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  
        )

    def forward(self, inputs):
        noise_vector, label = inputs  
        label_output = self.label_conditioned_generator(label)     
        label_output = label_output.view(-1, 1, 4, 4)        
        latent_output = self.latent(noise_vector)     
        latent_output = latent_output.view(-1, 512, 4, 4) 
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

输出
在这里插入图片描述

from torchinfo import summary
summary(generator)

输出
在这里插入图片描述

a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)

3.构建鉴别器

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     
            nn.Linear(embedding_dim, 3*128*128)         
        )
        
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),      
            nn.LeakyReLU(0.2, inplace=True),             
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               
            nn.Dropout(0.4),                            
            nn.Linear(4608, 1),                         
            nn.Sigmoid()                                
        )

    def forward(self, inputs):
        img, label = inputs
        
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        
        concat = torch.cat((img, label_output), dim=1)
        
        output = self.model(concat)
        return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

输出
在这里插入图片描述

summary(discriminator)

输出
在这里插入图片描述

a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()

输出
torch.Size([2, 1])

四、训练模型

1.定义损失函数

adversarial_loss = nn.BCELoss() 

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2.定义优化器

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3.训练模型

num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):
    
    D_loss_list, G_loss_list = [], []
    
    for index, (real_images, labels) in enumerate(train_loader):
        
        D_optimizer.zero_grad()
        
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        
        labels = labels.unsqueeze(1).long()
        
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch%10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出
在这里插入图片描述
在这里插入图片描述

五、模

1.Loss图

G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]
import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Descriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出
在这里插入图片描述

2.生成指定图像

from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec

generator.load_state_dict(torch.load("H:\generator_epoch_300.pth"), strict = False)
generator.eval()

interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

代码知识点
索引操作:predictions[0, :, :, :]表示从名为predictions的多维数组中获取第一个通道的所有数据。这里的:表示选择该维度上的所有元素。

数学运算:+ 1表示将每个元素加1,用于将预测结果的范围从-1到1映射到0到2之间。

数学运算:* 127.5表示将每个元素乘以127.5,用于将预测结果的范围从0到2映射到0到255之间,以便在图像显示时使用无符号整数类型(uint8)。

在深度学习中,尤其是在使用神经网络进行图像生成任务时,生成器网络的输出通常是一个介于-1到1之间的标准化浮点数张量。这种标准化是为了方便在训练过程中使用梯度下降算法,因为梯度下降算法对于小范围内的数值更加稳定和高效。

在许多情况下,生成器的最后一个层会使用双曲正切激活函数(Hyperbolic Tanh Activation
Function),其数学表达式为:

tanh ⁡ ( x ) = e 2 x − 1 e 2 x + 1 \tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1} tanh(x)=e2x+1e2x1

双曲正切函数将实值输入压缩到介于-1和1之间的输出。这意味着无论输入值是多少,输出值都会被限制在这个范围内。这种特性使得tanh激活函数在生成模型中非常有用,因为它可以确保生成的数据在特定的范围内,并且是零中心的(即平均值接近于0)。

在显示或处理这些生成的图像数据之前,需要将这些标准化的值转换回它们原来的像素值范围,通常是0到255的整数。这就是为什么代码中有+ 1的操作,它将tanh函数的输出范围从-1到1平移到0到2,然后通过乘以127.5将这个范围进一步扩展到0到255,这样就可以将浮点数转换为无符号整数(uint8)以便于图像显示。

总结一下,预测结果的范围是从-1到1的原因是因为使用了tanh激活函数,而将这个范围映射到0到255是为了能够在图像显示时正确地表示像素值。

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100


plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

输出
在这里插入图片描述

  • 7
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值