G3 - Gesture Image Generation: Getting Started with CGAN



Theoretical Background

CGAN (Conditional Generative Adversarial Network) is an improvement on the original GAN (Generative Adversarial Network). In the original GAN, the generator produces images from random, unpredictable noise, so there is no way to control what the network outputs, which limits its usefulness in practice.

To address the original GAN's inability to generate images with specific attributes, Mehdi Mirza et al. proposed CGAN in 2014. By feeding an additional condition to both the generator G and the discriminator D, the unsupervised GAN is turned into a supervised CGAN, so the network can be trained under our control.

For example, if we want the generator G to produce an image without shadows, the discriminator D then has to judge whether the image produced by the generator really is shadow-free.

In essence, CGAN injects extra information into both the generator and the discriminator. This extra information can be an image's class label, a facial expression, or any other auxiliary data.

The network structure is shown below:
Figure: CGAN network structure
As the figure shows, the condition y is introduced into the GAN as an extra input: in the generator it is merged with the noise z to form the hidden-layer representation, while in the discriminator D it is merged with the real data x as the input to the discriminator function. This modification has proven very effective in much of the follow-up research and has guided a great deal of subsequent work.
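
For reference, the conditional objective in Mirza et al.'s paper keeps the original GAN value function and simply conditions both the generator and the discriminator on y:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]$$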

Environment

  • Python 3.11
  • RTX 4090
  • PyTorch 2.1.0

Steps

Environment setup

Imports

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt

Create a global device object and set the batch size

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# batch size
batch_size = 128

Data preparation

Load the data

transform = transforms.Compose([
	transforms.Resize(128),
	transforms.ToTensor(),
	transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset = datasets.ImageFolder(
	root='data/rps', transform=transform)

train_loader = DataLoader(dataset=train_dataset,
						batch_size=batch_size,
						shuffle=True,
						num_workers=6)
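
ImageFolder assigns label indices by sorting the class sub-folder names alphabetically. A quick check (a minimal sketch; the printed names depend on the actual folder names under data/rps) makes the label-to-gesture mapping explicit, which matters later when generating images for a specific label:

# ImageFolder sorts sub-folder names and maps them to indices 0..n_classes-1
print(train_dataset.classes)        # e.g. ['paper', 'rock', 'scissors']
print(train_dataset.class_to_idx)   # e.g. {'paper': 0, 'rock': 1, 'scissors': 2}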

Inspect a batch from the dataset

def show_images(images):
    """Arrange the images into a grid and display them."""
    plt.figure(figsize=(20, 20))
    plt.axis('off')
    plt.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
    """Take one batch from the data loader and display it."""
    for images, _ in dl:
        show_images(images)
        break

show_batch(train_loader)

Figure: a sample batch from the dataset

Model design

First, set the parameters for the model's input, output and latent space

# image shape
image_shape = (3, 128, 128)
# dimensionality after flattening
image_dim = int(np.prod(image_shape))
# dimensionality of the latent (noise) vector
latent_dim = 100

# number of classes: rock, paper, scissors
n_classes = 3
# embedding dimension
embedding_dim = 100

Write the weight-initialization function

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

Build the generator

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # conditional branch: maps the class label into an embedding space
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),  # Embedding layer maps the label to a dense vector
            nn.Linear(embedding_dim, 16)  # linear layer maps the dense vector to 16 values (reshaped later to 1x4x4)
        )
        # latent branch: maps the noise vector into feature space
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  # linear layer expands the latent vector to 4x4x512
            nn.LeakyReLU(0.2, inplace=True)
        )
        # main generator body: merges the label and latent features and upsamples them into an image
        self.model = nn.Sequential(
            # deconv block 1: 513 x 4 x 4 -> 512 x 8 x 8
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # deconv block 2: 512 x 8 x 8 -> 256 x 16 x 16
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # deconv block 3: 256 x 16 x 16 -> 128 x 32 x 32
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # deconv block 4: 128 x 32 x 32 -> 64 x 64 x 64
            nn.ConvTranspose2d(64*2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # deconv block 5: 64 x 64 x 64 -> 3 x 128 x 128
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        # unpack the inputs
        noise_vector, label = inputs
        # compute the label embedding
        label_output = self.label_conditioned_generator(label)
        # reshape the embedding to (batch_size, 1, 4, 4)
        label_output = label_output.view(-1, 1, 4, 4)
        # map the noise vector to latent features
        latent_output = self.latent(noise_vector)
        # reshape the latent features to (batch_size, 512, 4, 4)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # concatenate the latent features and the label features along the channel dimension
        concat = torch.cat((latent_output, label_output), dim=1)
        # generate an RGB image from the merged feature map
        image = self.model(concat)
        return image
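
The sizes in the comments above can be checked against the standard ConvTranspose2d output-size formula, H_out = (H_in - 1) * stride - 2 * padding + kernel_size. A small sketch (the helper below is hypothetical, not part of the model) walks the 4x4 input through the five deconvolution blocks:

# output size of a ConvTranspose2d layer (no output_padding or dilation)
def deconv_out(h, kernel_size=4, stride=2, padding=1):
    return (h - 1) * stride - 2 * padding + kernel_size

h = 4
for i in range(5):
    h = deconv_out(h)
    print('after deconv block %d: %dx%d' % (i + 1, h, h))  # 8, 16, 32, 64, 128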

Instantiate the generator

# create the generator
generator = Generator().to(device)
# initialize the weights
generator.apply(weights_init)
# print the generator structure
summary(generator)

Generator structure output

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

Test the generator's output shape

a = torch.ones(100)
b = torch.ones(1)

b = b.long()
a = a.to(device)
b = b.to(device)

c = generator((a, b))
c.size()

Output: torch.Size([1, 3, 128, 128])

Build the discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_conditioned_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),  # Embedding layer encodes the class label as a fixed-length vector
            nn.Linear(embedding_dim, 3*128*128)  # linear layer expands the embedding to match the image size
        )
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),  # 6 input channels: 3 image channels + 3 label channels
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),  # flatten the feature map into a 1-D vector for the fully connected layer
            nn.Dropout(0.4),  # dropout to reduce the risk of overfitting
            nn.Linear(4608, 1),  # fully connected layer maps the feature vector to a single output
            nn.Sigmoid()  # squash the output to a probability in (0, 1)
        )

    def forward(self, inputs):
        img, label = inputs
        # turn the class label into a feature vector
        label_output = self.label_conditioned_disc(label)
        # reshape the feature vector to match the image dimensions
        label_output = label_output.view(-1, 3, 128, 128)
        # concatenate the image and the label feature map
        concat = torch.cat((img, label_output), dim=1)
        # forward pass
        output = self.model(concat)
        return output
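
The in_features=4608 of the final linear layer follows from the Conv2d output-size formula, H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1. A quick sketch (hypothetical helper) traces a 128x128 input through the four convolutions, ending at 512 channels of 3x3, i.e. 512 * 3 * 3 = 4608:

def conv_out(h, kernel_size, stride, padding):
    return (h + 2 * padding - kernel_size) // stride + 1

h = conv_out(128, 4, 2, 1)      # 64
for _ in range(3):
    h = conv_out(h, 4, 3, 2)    # 22 -> 8 -> 3
print(512 * h * h)              # 4608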

Instantiate the discriminator

# create the discriminator
discriminator = Discriminator().to(device)
# initialize the weights
discriminator.apply(weights_init)
# print the discriminator structure
summary(discriminator)

Discriminator structure

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================

Test the discriminator's output shape

a = torch.ones(2, 3, 128, 128)
b = torch.ones(2, 1)

b = b.long()
a = a.to(device)
b = b.to(device)

c = discriminator((a, b))
c.size()

Output: torch.Size([2, 1])

Model training

Define the loss functions

adversarial_loss = nn.BCELoss()

# generator loss
def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

# discriminator loss
def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
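
Both helpers wrap the same binary cross-entropy. For a prediction p and a target t it is

$$\mathcal{L}_{\mathrm{BCE}}(p, t) = -\left[t \log p + (1 - t) \log (1 - p)\right]$$

so the discriminator is pushed towards 1 on real pairs and towards 0 on fake pairs, while the generator is trained against the real target in order to fool the discriminator.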

Define the optimizers

learning_rate  = 0.0002

G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

Figure: structural difference between GAN and CGAN

Start training. Note that the ./images and ./training_weights directories need to exist before samples and checkpoints are saved.

# total number of training epochs
num_epochs = 100
# lists that store the per-epoch discriminator and generator losses
D_loss_plot, G_loss_plot = [], []

# training loop
for epoch in range(1, num_epochs + 1):
    # temporary lists for the discriminator and generator losses within this epoch
    D_loss_list, G_loss_list = [], []
    # iterate over the training data loader
    for index, (real_images, labels) in enumerate(train_loader):
        # clear the discriminator's gradients
        D_optimizer.zero_grad()
        # move the data to the GPU
        real_images = real_images.to(device)
        labels = labels.to(device)

        # reshape the labels from a 1-D vector to a 2-D tensor
        labels = labels.unsqueeze(1).long()
        # target vectors for real and fake samples
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        # discriminator loss on real images
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        # generate fake images from noise vectors
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        generated_image = generator((noise_vector, labels))

        # discriminator loss on fake images
        # detach() is essential here: it cuts the generated images out of the generator's graph
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        # total discriminator loss
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        # back-propagate and update the discriminator
        D_total_loss.backward()
        D_optimizer.step()

        # clear the generator's gradients
        G_optimizer.zero_grad()
        # generator loss
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        # back-propagate and update the generator
        G_loss.backward()
        G_optimizer.step()

    # print the average discriminator and generator losses for this epoch
    D_batch_loss = torch.mean(torch.FloatTensor(D_loss_list))
    G_batch_loss = torch.mean(torch.FloatTensor(G_loss_list))
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (epoch, num_epochs, D_batch_loss, G_batch_loss))
    # store the per-epoch averages
    D_loss_plot.append(D_batch_loss)
    G_loss_plot.append(G_batch_loss)

    if epoch % 10 == 0:
        # save a sample of the generated images and the current model weights
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % epoch)
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % epoch)

Training output

Epoch: [1/100]: D_loss: 0.337, G_loss: 1.425
Epoch: [2/100]: D_loss: 0.204, G_loss: 2.700
Epoch: [3/100]: D_loss: 0.319, G_loss: 2.214
Epoch: [4/100]: D_loss: 0.274, G_loss: 2.075
Epoch: [5/100]: D_loss: 0.324, G_loss: 2.306
Epoch: [6/100]: D_loss: 0.409, G_loss: 2.323
Epoch: [7/100]: D_loss: 0.434, G_loss: 1.669
Epoch: [8/100]: D_loss: 0.419, G_loss: 1.580
Epoch: [9/100]: D_loss: 0.366, G_loss: 1.750
Epoch: [10/100]: D_loss: 0.429, G_loss: 1.798
Epoch: [11/100]: D_loss: 0.486, G_loss: 2.160
Epoch: [12/100]: D_loss: 0.572, G_loss: 1.873
Epoch: [13/100]: D_loss: 0.476, G_loss: 1.592
Epoch: [14/100]: D_loss: 0.435, G_loss: 1.692
Epoch: [15/100]: D_loss: 0.432, G_loss: 1.862
Epoch: [16/100]: D_loss: 0.438, G_loss: 1.735
Epoch: [17/100]: D_loss: 0.459, G_loss: 1.726
Epoch: [18/100]: D_loss: 0.423, G_loss: 1.846
Epoch: [19/100]: D_loss: 0.387, G_loss: 1.859
Epoch: [20/100]: D_loss: 0.422, G_loss: 2.201
Epoch: [21/100]: D_loss: 0.398, G_loss: 2.060
Epoch: [22/100]: D_loss: 0.370, G_loss: 2.159
Epoch: [23/100]: D_loss: 0.388, G_loss: 2.127
Epoch: [24/100]: D_loss: 0.362, G_loss: 2.134
Epoch: [25/100]: D_loss: 0.403, G_loss: 2.207
Epoch: [26/100]: D_loss: 0.374, G_loss: 1.892
Epoch: [27/100]: D_loss: 0.459, G_loss: 1.837
Epoch: [28/100]: D_loss: 0.438, G_loss: 1.661
Epoch: [29/100]: D_loss: 0.460, G_loss: 1.613
Epoch: [30/100]: D_loss: 0.468, G_loss: 1.535
Epoch: [31/100]: D_loss: 0.434, G_loss: 1.467
Epoch: [32/100]: D_loss: 0.501, G_loss: 1.583
Epoch: [33/100]: D_loss: 0.435, G_loss: 1.475
Epoch: [34/100]: D_loss: 0.458, G_loss: 1.505
Epoch: [35/100]: D_loss: 0.458, G_loss: 1.523
Epoch: [36/100]: D_loss: 0.460, G_loss: 1.524
Epoch: [37/100]: D_loss: 0.451, G_loss: 1.541
Epoch: [38/100]: D_loss: 0.395, G_loss: 1.508
Epoch: [39/100]: D_loss: 0.510, G_loss: 1.633
Epoch: [40/100]: D_loss: 0.427, G_loss: 1.553
Epoch: [41/100]: D_loss: 0.448, G_loss: 1.635
Epoch: [42/100]: D_loss: 0.435, G_loss: 1.638
Epoch: [43/100]: D_loss: 0.433, G_loss: 1.559
Epoch: [44/100]: D_loss: 0.426, G_loss: 1.620
Epoch: [45/100]: D_loss: 0.414, G_loss: 1.616
Epoch: [46/100]: D_loss: 0.555, G_loss: 1.830
Epoch: [47/100]: D_loss: 0.377, G_loss: 1.677
Epoch: [48/100]: D_loss: 0.488, G_loss: 1.752
Epoch: [49/100]: D_loss: 0.394, G_loss: 1.680
Epoch: [50/100]: D_loss: 0.432, G_loss: 1.725
Epoch: [51/100]: D_loss: 0.370, G_loss: 1.673
Epoch: [52/100]: D_loss: 0.587, G_loss: 2.044
Epoch: [53/100]: D_loss: 0.379, G_loss: 1.760
Epoch: [54/100]: D_loss: 0.388, G_loss: 1.741
Epoch: [55/100]: D_loss: 0.395, G_loss: 1.739
Epoch: [56/100]: D_loss: 0.385, G_loss: 1.848
Epoch: [57/100]: D_loss: 0.417, G_loss: 1.840
Epoch: [58/100]: D_loss: 0.383, G_loss: 1.775
Epoch: [59/100]: D_loss: 0.371, G_loss: 1.804
Epoch: [60/100]: D_loss: 0.375, G_loss: 1.824
Epoch: [61/100]: D_loss: 0.441, G_loss: 1.999
Epoch: [62/100]: D_loss: 0.405, G_loss: 1.964
Epoch: [63/100]: D_loss: 0.370, G_loss: 1.950
Epoch: [64/100]: D_loss: 0.373, G_loss: 1.878
Epoch: [65/100]: D_loss: 0.347, G_loss: 1.899
Epoch: [66/100]: D_loss: 0.402, G_loss: 2.059
Epoch: [67/100]: D_loss: 0.343, G_loss: 1.911
Epoch: [68/100]: D_loss: 0.499, G_loss: 2.258
Epoch: [69/100]: D_loss: 0.347, G_loss: 1.970
Epoch: [70/100]: D_loss: 0.325, G_loss: 1.994
Epoch: [71/100]: D_loss: 0.356, G_loss: 2.073
Epoch: [72/100]: D_loss: 0.373, G_loss: 2.078
Epoch: [73/100]: D_loss: 0.326, G_loss: 2.036
Epoch: [74/100]: D_loss: 0.333, G_loss: 2.109
Epoch: [75/100]: D_loss: 0.502, G_loss: 2.346
Epoch: [76/100]: D_loss: 0.313, G_loss: 2.135
Epoch: [77/100]: D_loss: 0.342, G_loss: 2.142
Epoch: [78/100]: D_loss: 0.573, G_loss: 2.343
Epoch: [79/100]: D_loss: 0.297, G_loss: 2.079
Epoch: [80/100]: D_loss: 0.296, G_loss: 2.111
Epoch: [81/100]: D_loss: 0.299, G_loss: 2.192
Epoch: [82/100]: D_loss: 0.271, G_loss: 2.149
Epoch: [83/100]: D_loss: 0.318, G_loss: 2.294
Epoch: [84/100]: D_loss: 0.359, G_loss: 2.283
Epoch: [85/100]: D_loss: 0.279, G_loss: 2.244
Epoch: [86/100]: D_loss: 0.339, G_loss: 2.404
Epoch: [87/100]: D_loss: 0.310, G_loss: 2.343
Epoch: [88/100]: D_loss: 0.430, G_loss: 2.496
Epoch: [89/100]: D_loss: 0.267, G_loss: 2.323
Epoch: [90/100]: D_loss: 0.254, G_loss: 2.371
Epoch: [91/100]: D_loss: 0.505, G_loss: 2.713
Epoch: [92/100]: D_loss: 0.270, G_loss: 2.389
Epoch: [93/100]: D_loss: 0.276, G_loss: 2.396
Epoch: [94/100]: D_loss: 0.334, G_loss: 2.482
Epoch: [95/100]: D_loss: 0.256, G_loss: 2.432
Epoch: [96/100]: D_loss: 0.251, G_loss: 2.434
Epoch: [97/100]: D_loss: 0.294, G_loss: 2.523
Epoch: [98/100]: D_loss: 0.282, G_loss: 2.487
Epoch: [99/100]: D_loss: 0.424, G_loss: 2.816
Epoch: [100/100]: D_loss: 0.303, G_loss: 2.609
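
The per-epoch averages collected in D_loss_plot and G_loss_plot can be plotted to see how the adversarial game evolves over the 100 epochs; a minimal sketch, assuming it is run in the same session right after training:

plt.figure(figsize=(8, 4))
plt.plot([float(v) for v in D_loss_plot], label='D_loss')
plt.plot([float(v) for v in G_loss_plot], label='G_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()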

Results

Generate images with the trained model
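
If generation is done in a fresh session, the weights saved during training can be loaded back first; a minimal sketch, assuming the epoch-100 checkpoint from the loop above exists under ./training_weights:

# load a saved generator checkpoint (the exact filename depends on your run)
generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth', map_location=device))
generator.eval()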

# generate points in the latent space, used as input to the generator
def generate_latent_points(latent_dim, n_samples, n_classes=3):
    # sample latent points from a standard normal distribution
    x_input = np.random.randn(latent_dim * n_samples)
    # reshape into a batch suitable as network input
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

# uniform interpolation between two points in the latent space
def interpolate_points(p1, p2, n_steps=10):
    # interpolation ratios between the two points
    ratios = np.linspace(0, 1, num=n_steps)
    # linearly interpolated vectors
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return np.asarray(vectors)

# generate two points in the latent space
pts = generate_latent_points(100, 2)
# interpolate between the two latent points
interpolated = interpolate_points(pts[0], pts[1])

# convert to a torch tensor and move it to the GPU
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

output = None
# loop over the three classes, interpolating and generating images for each
for label in range(3):
    # a tensor of identical class labels
    labels = torch.ones(10)*label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    # generate images from the interpolated latent points
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0, 2, 3, 1)
    pred = predictions.detach().cpu()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output, pred))

Display the generated images. Instead of make_grid, matplotlib subplots are used here.

from matplotlib import gridspec
nrow = 3
ncol = 10

plt.figure(figsize=(15, 4))
gs = gridspec.GridSpec(nrow, ncol)

k = 0

for i in range(nrow):
    for j in range(ncol):
        pred = (output[k, :, :, :] + 1 ) *127.5
        pred = np.array(pred)
        plt.subplot(gs[i, j])
        plt.axis('off')
        plt.imshow(pred.astype(np.uint8))
        k += 1

Figure: model results (latent-space interpolation for the three classes)
How to generate an image with a specified label

# generate one latent vector and one label; feeding them to the generator produces one image
def generate_image(class_idx):
    latent = torch.tensor(np.random.randn(latent_dim)).type(torch.float32).unsqueeze(0).to(device)
    labels = (torch.ones(1)*class_idx).unsqueeze(0).long().to(device)
    predictions = generator((latent, labels)).permute(0, 2, 3, 1).detach().cpu()
    return predictions[0]

# label 0 corresponds to paper; calling this repeatedly keeps producing paper
image = generate_image(0)
plt.axis('off')
plt.imshow((image + 1) / 2)  # rescale from [-1, 1] to [0, 1] for display
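
To get one sample per class in a single grid, generate_image can be combined with the make_grid utility imported earlier; a minimal sketch:

# generate one image per class and show them side by side
samples = torch.stack([generate_image(c).permute(2, 0, 1) for c in range(3)])  # (3, 3, 128, 128)
grid = make_grid((samples + 1) / 2, nrow=3)  # rescale from [-1, 1] to [0, 1]
plt.figure(figsize=(9, 3))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))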

Figure: generated result

Summary and Reflections

Through this experiment I learned how to control what a GAN generates: the label is turned into an embedding vector, merged with the randomly sampled latent vector, and fed into the generator, which then produces a result that matches the given label.

There is also something that felt wrong in the previous experiment whose cause I could not pin down: during the discriminator's backward pass, detach() is used to cut the connection to the generator so the two do not interfere with each other. In the previous experiment the discriminator's real_loss and fake_loss were back-propagated separately; in this experiment real_loss and fake_loss are summed first and back-propagated in a single pass.
