昇思25天学习打卡营第42天|生成式-Pix2Pix实现图像转换
Pix2Pix实现图像转换(cGAN,条件生成对抗网络)
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。
生成器(Generator):在cGAN中,生成器的输入是随机噪声向量(通常用 z 表示)和条件信息 y。生成器试图学习从这些输入到目标数据分布的映射,以生成符合条件 y 的数据样本 G(z∣y)。
判别器(Discriminator):判别器的输入是生成的数据样本和条件信息,以及真实的数据样本和条件信息。判别器的目标是区分由生成器生成的数据和真实数据,即输出一个概率值 D(x∣y),表示样本 x 在给定条件 y下是否是真实数据的概率。
cGAN与传统GAN比较
输入的区别:
- 传统GAN:生成器只接收随机噪声 z作为输入,判别器接收生成的数据和真实数据。
- cGAN:生成器和判别器都接收额外的条件信息 y。生成器的输入是 (z,y),判别器的输入是 (x,y) 。
输出的区别:
- 传统GAN:生成器输出是模拟真实数据分布的样本,判别器输出一个二元分类的概率。
- cGAN:生成器输出是符合条件 y 的样本,判别器输出是在条件 y下数据是否真实的概率。
优点:
- 可控性:cGAN可以通过条件信息控制生成数据的特定属性,如生成特定类型的图像、声音等。
- 多样性:通过不同的条件信息,cGAN可以生成多样化的样本。
缺点:
- 模式崩溃(Mode Collapse):即使在条件信息的指导下,生成器可能仍会倾向于生成类似的样本,导致生成样本缺乏多样性。后续可通过在训练和测试时都使用dropout,来产生多样化的结果。
- 训练不稳定:与传统GAN一样,cGAN的训练过程也可能出现不稳定,尤其是生成器和判别器的能力不均衡时。
基于PatchGAN的判别器
PatchGAN是一种特殊的判别器,它不是判断整幅图像的真伪,而是判断图像中每个局部区域(patch)的真伪。这样有助于提升模型的细节辨别能力,并在处理高分辨率图像时更加高效。
import mindspore.nn as nn
class ConvNormRelu(nn.Cell):
def __init__(self,
in_planes,
out_planes,
kernel_size=4,
stride=2,
alpha=0.2,
norm_mode='batch',
pad_mode='CONSTANT',
use_relu=True,
padding=None):
super(ConvNormRelu, self).__init__()
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if not padding:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
has_bias=has_bias, padding=padding)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class Discriminator(nn.Cell):
def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
super(Discriminator, self).__init__()
kernel_size = 4
layers = [
nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
nn.LeakyReLU(alpha)
]
nf_mult = ndf
for i in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** i, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
self.features = nn.SequentialCell(layers)
def construct(self, x, y):
x_y = ops.concat((x, y), axis=1)
output = self.features(x_y)
return output
ConvNormRelu
类:用于构建一个包含卷积层、归一化层和激活层的序列。它的作用是简化卷积块的构建。
Discriminator
类:继承自 nn.Cell
。它实现了PatchGAN的架构。
训练过程
- 前向传播的生成和判别:生成器接收一个随机噪声向量和条件信息(如标签、图像、文本等),并通过网络生成一个假样本(例如图像);判别器接收生成的假样本与条件信息的组合,以及真实样本与条件信息的组合,并输出一个判别结果。
- 计算并更新判别器:将真实数据与条件信息输入判别器,判别器应该输出一个高概率表示真实;将生成的数据与条件信息输入判别器,判别器应该输出一个低概率表示虚假。这两部分损失的总和构成了判别器的总损失函数。然后,使用反向传播算法计算损失相对于判别器参数的梯并更新判别器参数。
- 计算并更新生成器:生成器的损失也基于判别器的输出,但与判别器的目标相反。它的损失是判别器对生成数据输出的预测与真实标签(通常为1)之间的误差。接着使用反向传播计算生成器损失相对于生成器参数的梯度,最后更新参数。
- 迭代:重复上述步骤,直至50%概率、到达一定的epoch或者数据质量达到一定的指标(如用FID分数来评估)。
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100
def get_lr():
lrs = [lr] * dataset_size * n_epochs
lr_epoch = 0
for epoch in range(n_epochs_decay):
lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
lrs += [lr_epoch] * dataset_size
lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
return Tensor(np.array(lrs).astype(np.float32))
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forword_dis(reala, realb):
lambda_dis = 0.5
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
pred1 = net_discriminator(reala, realb)
loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
loss_dis = loss_d * lambda_dis
return loss_dis
def forword_gan(reala, realb):
lambda_gan = 0.5
lambda_l1 = 100
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
loss_1 = loss_f(pred0, ops.ones_like(pred0))
loss_2 = l1_loss(fakeb, realb)
loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
return loss_gan
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
loss_dis, d_grads = grad_d(reala, realb)
loss_gan, g_grads = grad_g(reala, realb)
d_opt(d_grads)
g_opt(g_grads)
return loss_dis, loss_gan
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
for i, data in enumerate(data_loader):
start_time = datetime.datetime.now()
input_image = Tensor(data["input_images"])
target_image = Tensor(data["target_images"])
dis_loss, gen_loss = train_step(input_image, target_image)
end_time = datetime.datetime.now()
delta = (end_time - start_time).microseconds
if i % 2 == 0:
print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
d_losses.append(dis_loss.asnumpy())
g_losses.append(gen_loss.asnumpy())
if (epoch + 1) == epoch_num:
mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")