以下为官方活动的学习笔记兼打卡记录,大部分内容来自活动资料,稍有删改,以我自己能看懂为准。
一、Pix2Pix简介
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。
传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。
基础原理
cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。
在教程开始前,首先定义一些在整个过程中需要用到的符号:
- x x x:代表观测图像的数据。
- z z z:代表随机噪声的数据。
- y = G ( x , z ) y=G(x,z) y=G(x,z):生成器网络,给出由观测图像 x x x与随机噪声 z z z生成的“假”图片,其中 x x x来自于训练数据而非生成器。
- D ( x , G ( x , z ) ) D(x,G(x,z)) D(x,G(x,z)):判别器网络,给出图像判定为真实图像的概率,其中 x x x来自于训练数据, G ( x , z ) G(x,z) G(x,z)来自于生成器。
cGAN的目标可以表示为:
L c G A N ( G , D ) = E ( x , y ) [ l o g ( D ( x , y ) ) ] + E ( x , z ) [ l o g ( 1 − D ( x , G ( x , z ) ) ) ] L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))] LcGAN(G,D)=E(x,y)[log(D(x,y))]+E(x,z)[log(1−D(x,G(x,z)))]
该公式是cGAN的损失函数,D
想要尽最大努力去正确分类真实图像与“假”图像,也就是使参数
l
o
g
D
(
x
,
y
)
log D(x,y)
logD(x,y)最大化;而G
则尽最大努力用生成的“假”图像
y
y
y欺骗D
,避免被识破,也就是使参数
l
o
g
(
1
−
D
(
G
(
x
,
z
)
)
)
log(1−D(G(x,z)))
log(1−D(G(x,z)))最小化。cGAN的目标可简化为:
a r g min G max D L c G A N ( G , D ) arg\min_{G}\max_{D}L_{cGAN}(G,D) argGminDmaxLcGAN(G,D)
为了对比cGAN和GAN的不同,我们将GAN的目标也进行了说明:
L G A N ( G , D ) = E y [ l o g ( D ( y ) ) ] + E ( x , z ) [ l o g ( 1 − D ( x , z ) ) ] L_{GAN}(G,D)=E_{y}[log(D(y))]+E_{(x,z)}[log(1-D(x,z))] LGAN(G,D)=Ey[log(D(y))]+E(x,z)[log(1−D(x,z))]
从公式可以看出,GAN直接由随机噪声 z z z生成“假”图像,不借助观测图像 x x x的任何信息。过去的经验告诉我们,GAN与传统损失混合使用是有好处的,判别器的任务不变,依旧是区分真实图像与“假”图像,但是生成器的任务不仅要欺骗判别器,还要在传统损失的基础上接近训练数据。假设cGAN与L1正则化混合使用,那么有:
L L 1 ( G ) = E ( x , y , z ) [ ∣ ∣ y − G ( x , z ) ∣ ∣ 1 ] L_{L1}(G)=E_{(x,y,z)}[||y-G(x,z)||_{1}] LL1(G)=E(x,y,z)[∣∣y−G(x,z)∣∣1]
进而得到最终目标:
a r g min G max D L c G A N ( G , D ) + λ L L 1 ( G ) arg\min_{G}\max_{D}L_{cGAN}(G,D)+\lambda L_{L1}(G) argGminDmaxLcGAN(G,D)+λLL1(G)
图像转换问题本质上其实就是像素到像素的映射问题,Pix2Pix使用完全一样的网络结构和目标函数,仅更换不同的训练数据集就能分别实现以上的任务。本任务将借助MindSpore框架来实现Pix2Pix的应用。
二、数据集
在本教程中,我们将使用指定数据集,该数据集是已经经过处理的外墙(facades)数据,可以直接使用mindspore.dataset的方法读取。
三、网络结构
生成器G用到的是U-Net结构,输入的轮廓图 x x x编码再解码成真实图片,判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图 x x x的条件下,对于生成的图片 G ( x ) G(x) G(x)判断为假,对于真实判断为真。
3.1 生成器
原始cGAN的输入是条件x和噪声z两种信息,这里的生成器只使用了条件信息,因此不能生成多样性的结果。因此Pix2Pix在训练和测试时都使用了dropout,这样可以生成多样性的结果。
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
class UNetSkipConnectionBlock(nn.Cell):
def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
super(UNetSkipConnectionBlock, self).__init__()
down_norm = nn.BatchNorm2d(inner_nc)
up_norm = nn.BatchNorm2d(outer_nc)
use_bias = False
if norm_mode == 'instance':
down_norm = nn.BatchNorm2d(inner_nc, affine=False)
up_norm = nn.BatchNorm2d(outer_nc, affine=False)
use_bias = True
if in_planes is None:
in_planes = outer_nc
down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
down_relu = nn.LeakyReLU(alpha)
up_relu = nn.ReLU()
if outermost:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, pad_mode='pad')
down = [down_conv]
up = [up_relu, up_conv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [down_relu, down_conv]
up = [up_relu, up_conv, up_norm]
model = down + up
else:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [down_relu, down_conv, down_norm]
up = [up_relu, up_conv, up_norm]
model = down + [submodule] + up
if dropout:
model.append(nn.Dropout(p=0.5))
self.model = nn.SequentialCell(model)
self.skip_connections = not outermost
def construct(self, x):
out = self.model(x)
if self.skip_connections:
out = ops.concat((out, x), axis=1)
return out
class UNetGenerator(nn.Cell):
def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
super(UNetGenerator, self).__init__()
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
norm_mode=norm_mode, innermost=True)
for _ in range(n_layers - 5):
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode, dropout=dropout)
unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
outermost=True, norm_mode=norm_mode)
def construct(self, x):
return self.model(x)
3.2 判别器
判别器使用的PatchGAN结构,可看做卷积。生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。
import mindspore.nn as nn
class ConvNormRelu(nn.Cell):
def __init__(self,
in_planes,
out_planes,
kernel_size=4,
stride=2,
alpha=0.2,
norm_mode='batch',
pad_mode='CONSTANT',
use_relu=True,
padding=None):
super(ConvNormRelu, self).__init__()
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if not padding:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
has_bias=has_bias, padding=padding)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class Discriminator(nn.Cell):
def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
super(Discriminator, self).__init__()
kernel_size = 4
layers = [
nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
nn.LeakyReLU(alpha)
]
nf_mult = ndf
for i in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** i, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
self.features = nn.SequentialCell(layers)
def construct(self, x, y):
x_y = ops.concat((x, y), axis=1)
output = self.features(x_y)
return output
四、训练与推理
4.1 损失函数与优化器
def forword_dis(reala, realb):
lambda_dis = 0.5
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
pred1 = net_discriminator(reala, realb)
loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
loss_dis = loss_d * lambda_dis
return loss_dis
def forword_gan(reala, realb):
lambda_gan = 0.5
lambda_l1 = 100
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
loss_1 = loss_f(pred0, ops.ones_like(pred0))
loss_2 = l1_loss(fakeb, realb)
loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
return loss_gan
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())
4.2 推理结果
训练过程只持续了3个epoch,获取训练过程完成后的ckpt文件,通过load_checkpoint和load_param_into_net将ckpt中的权重参数导入到模型中,获取数据进行推理并对推理的效果图进行演示。
理想状况下,各数据集推理效果如下: