网络介绍
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。
下面介绍本网络用到的生成器、判别器:
生成器G用到的是U-Net结构,输入的轮廓图 𝑥编码再解码成真是图片,
判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图𝑥的条件下,对于生成的图片 𝐺(𝑥)判断为假,对于真实判断为真。
训练
训练分为两个主要部分:训练判别器和训练生成器。
训练判别器的目的是最大程度地提高判别图像真伪的概率。
训练生成器是希望能产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。部分实现如下:
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100
def get_lr():
lrs = [lr] * dataset_size * n_epochs
lr_epoch = 0
for epoch in range(n_epochs_decay):
lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
lrs += [lr_epoch] * dataset_size
lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
return Tensor(np.array(lrs).astype(np.float32))
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forword_dis(reala, realb):
lambda_dis = 0.5
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
pred1 = net_discriminator(reala, realb)
loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
loss_dis = loss_d * lambda_dis
return loss_dis
def forword_gan(reala, realb):
lambda_gan = 0.5
lambda_l1 = 100
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
loss_1 = loss_f(pred0, ops.ones_like(pred0))
loss_2 = l1_loss(fakeb, realb)
loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
return loss_gan
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
loss_dis, d_grads = grad_d(reala, realb)
loss_gan, g_grads = grad_g(reala, realb)
d_opt(d_grads)
g_opt(g_grads)
return loss_dis, loss_gan
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
for i, data in enumerate(data_loader):
start_time = datetime.datetime.now()
input_image = Tensor(data["input_images"])
target_image = Tensor(data["target_images"])
dis_loss, gen_loss = train_step(input_image, target_image)
end_time = datetime.datetime.now()
delta = (end_time - start_time).microseconds
if i % 2 == 0:
print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
d_losses.append(dis_loss.asnumpy())
g_losses.append(gen_loss.asnumpy())
if (epoch + 1) == epoch_num:
mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
推理结果
此章节学习到此结束,感谢昇思平台。