Image-to-Image Translation with Conditional GAN :像素级别图像到图像的变换

前言:

假设看到这篇文章的炼丹师对GAN有一定程度的了解,文章中提到GAN的一些基本概念不做深入探究,主要内容放在Conditonal GAN以及keras实现上,重点是算法的原理以及实现过程。掌握算法的原理可以解决为什么的问题即算法的本质,熟悉实现过程可以解决怎么用的问题即算法的工程化。

第一章 背景

这篇文章内容基于paper《Image-to-Image Translation with Conditional Adversarial Networks》。每一种算法的提出都是为了解决特定的问题,作者在paper中提到以往在处理输入图像到输出图像映射时都需要很多不同的特殊算法,虽然这些问题的实质都是像素到像素的映射问题。所以作者提出基于Conditional GAN的Image2image Translation算法,希望提供一个能够能够应用于不同场景的通用方法。

举一个例子,使用谷歌翻译或者百度翻译时只需要选择想要的源语言和目的语言就能够进行翻译,不用为了英译中、中译法每次选择不同的平台,作者希望做的类似这种翻译平台,只需要调整训练数据不用每次根据不同的任务调整算法结构,比如:semantic labels<-->photo、Architectural labels-->photo、Map<-->aerial、BW-->color photos、Sketch-->photo、Day-->night、Thermal→color photos、Photo with missing pixels→inpainted photo等。上面这些场景作者在paper中已经根据对应的数据集进行过验证,使用相同的架构,都能取得理想的效果。见图一。

图一

第二章 算法结构

GAN的基本结构:G(generator)即生成器以及D(discriminator)即辨识器,可以直观的认为G(生成器)生成与目标图像相似甚至是相同的图像,即无中生有且生成的和已有的基本没有区别,D(辨识器)则练就一双火眼金睛把G生成的图像和真实图像区别开,这也是GAN的对抗性所在,G尽可能的愚弄D让它不能作出正确的判断,D则要尽可能不被G愚弄,练就一双慧眼。下图二是对GAN过程的简单描述。

图二

Z代表随机噪声,训练一个Generator生成Fake images,然后利用Fake images和Real images训练Discriminator,由于GAN的对抗性,所以理论上GAN只有相对最优解而没有真正的最优解,这一点与检测以及分类任务不同。

Conditional GAN与一版GAN相比,多了一个condition,在condition的基础上学习图像到图像的变换,condition不一定是图像也可以是其他的形式,在这里只以图像为例。

图三 Conditional GAN

图三是Conditional GAN的过程解释简图,可以看到与一版GAN最明显的区别就是多了一个condition,即Generator中多了一个y,Discriminator中多了一个y,y在这里表示本次任务所依赖的condition。G在训练的时候要尽可能的靠近condition,D在辨识的时候同样需要参考condition的信息。

Objective

由于算法结构的不同,conditional GAN与一版GAN有着不同的loss函数。

conditional GAN的loss函数:

L_{cGAN}\left ( G,D \right ) = E_{x,y}\left [ logD\left ( x,y \right ) \right ] +E_{x,z}\left [ log\left ( 1-D\left ( x,G\left ( x,z \right ) \right ) \right ) \right ]

一般GAN的loss函数:

L_{GAN}\left ( G,D \right ) = E_{y}\left [ logD\left ( y \right ) \right ] +E_{z}\left [ log\left ( 1-D\left ( G\left ( z \right ) \right ) \right ) \right ]

这两者的不同正如前面所提到的,conditional GAN关注condition,一般GAN则没有这么做。另外,作者提到GAN的objective混入传统的L1 loss或者L2 loss是对结果有所提升的。引入L1/L2 loss,discriminator的任务没有发生变化,但是Generator的任务即包括欺骗discriminator又包含了尽可能向ground truth靠近,换句话说Generator在满足GAN的基础上还必须满足Conditional GAN。作者通过实验证明L_{cGAN}\left ( G,D \right ) 引入一个L1 loss后效果最为理想生成的图像模糊情况较轻,论文中是这样写的:using L1 distance rather than L2 as L1 encourages less blurring。

所以得到最终的loss函数为

L^{*}= L_{cGAN}\left ( G,D \right ) + L_{L1}\left ( G \right )

L_{L1}\left ( G \right ) = E_{x,y,z}\left [ \left \| y-G\left ( x,z \right ) \right \| \right}_{1}]

补充说明:z一般代表高斯噪声,为什么要在输入中加入高斯噪声,作者是这样解释的

Without z, the net could still learn a mapping from x to y, but would produce deterministic outputs, and therefore fail to match any distribution other than a delta function. 

如果不引入z,网络依然可以学习从x到y的映射,但这样只会产生确定性的输出,这会导致除三角函数以外,
无法匹配任何分布,因为这种映射关系是多种多样的,如果模型的输出是确定性的换句话说就是唯一的,
那么这个模型是无法广泛适用的。

另外,作者提到在实际测试过程中,虽然引入了z但是z往往会被generator给忽略掉,为了解决这个问题,作者以dropout形式在一些层中引入噪声z,但是作者提到尽管存在噪声损失,在网络的输出中也仅观察到较小的随机性,设计产生高随机输出的Conditional GAN,从而捕获其建模的条件分布的全部熵,这是当前遗留的一个重要问题。总结就是引入dropout噪声有用但不是最优解,这个问题还有待解决。

网络结构

在论文中作者有提到,Generator是一个U-NET结构的网络,Discriminator是一个patch-GAN结构的网络。

Generator 为什么要是一个U-NET结构,作者在论文中是这样说的:

1、图像到图像的翻译问题有一个必须要满足的特征,从高分辨率的输入映射到高分辨率的输出
2、虽然输入和输入图像的外观不同,但是它们共享一些基本的特征,所以输入图像和输出图像要满足大概对齐
3、有大量low-level的特征在输入与输出之间是共享的,比如在BW-RGB任务中,图像中物体的边缘特征,这些特征要能够在网络中直接传递而不必通过所有的层。

正是基于这些条件,作者设计了U-NET网络结构而不是encoder-decoder模型,U-NET模型中的skip connections让输入图像的low-level特征能够直接传递到输出图像中,如果一层层的传递没有skip connections,这些low-level的特征会随着网络层数的加深而逐渐消失,skip connections保证了最后的输出图像即capture输入图像的low-level特征又caputre到high-level特征,满足设计的初衷。

关于skip connections的方式,作者是这样定义的:

Specifically, we add skip connections between each layer i and layer n − i, 
where n is the total number of layers. Each skip connection simply concatenates 
all channels at layer i with those at layer n − i.


在第i层和第n-i层之间加入skip connection,注意n是Generator网络的总层数,
然后在第n-i层上把第i层传递过来的特征和第n-i层上原始的特征进行一个concatenate,
注意是在channel维度上进行连接,这就必须保证传入的特征和原始特征的H,
W必须保持一致,出现不一致的问题可以通过计算此时的特征(H,W)然后进行调整。

Discriminator是类似马尔可夫链的PatchGAN模型,作者提出这个模型是基于L1 loss已经针对低频部分进行了校正,所以只需要关注图像中的高频部分即可。为了只针对高频部分建模,把图像划分成大量N*N的小块然后关注这些小块的特征是一个非常高校的方式。作者把这种discriminator的结构叫做PatchGAN,模型在整个图像上做卷积然后判断每个N*N的小块是real还是false,进而得到一个矩阵,输出矩阵的平均值作为最终的输出。至于为什么说是类似马尔可夫模型,是因为假设距离大于等于一个图像块直径的两个像素相互独立这种鉴别器(Discriminator)就可以有效地将图像建模为马尔可夫随机场。另外,N较小时由于每个图像块参数比较少所以模型会跑的比较快,而且可以直接在较大的图像中应用这个方法,也是比较高效的。

第三章 算法实现

论文中提到的小技巧

1、不直接训练Generator,通过把Generator和Discriminator(冻结,不更新参数) concate到一起,
然后训练这个连接模型达到训练Generator的目的

2、D 的loss值要除以2降低D根据G调整的速率

3、使用Adam优化方法,初始学习率设置为0.0002,momentum 参数 beta1=0.5,beta2=0.999

等博主有时间单独说明一下各种优化函数以及它们的推导公式。

对于Generator以及PatchGAN来说,可以把它们看成一个函数,这个函数既有输入又有输出,然后我们可以根据实际需求设计函数的输入和输出形式。

对于Generator,定义一个generator函数,输入是(x,z)x表示condition,z表示噪声(在这里以dropout的形式引入),输出是这个函数根据x(以下用img_B表示x)生成的“假图”这里用fake_A表示,因为在代码中我是这样表示的,这里用同样的表示方式便于解释代码和算法过程,img_A表示与img_B对应的真实图像也即为generator要靠近的对象。前文提到过generator是U-NET结构,所谓U-NET结构简单来讲就是几个downsample加上几个upsample以及跨层的skip connections,结合代码更好理解:

import keras
from keras.layers import Conv2D,Input,Dropout,Concatenate,BatchNormalization,Activation,ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D
from keras.models import Model

def generator(img_shape):
    def conv2d(layer_input,filters,f_size=4,bn=True):
        d = Conv2D(filters,kernel_size=f_size,strides=2,padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d
    def deconv2d(layer_input,skip_input,filters,f_size=4,dropout_rate=0):
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters,kernel_size=f_size,strides=1,padding='same',activation='relu')(u)
        if dropout_rate:
          u = Dropout(dropout_rate)(u)
        u = BatchNormalization(momentum=0.8)(u)
        u = Concatenate()([u,skip_input])
        return u
    d0 = Input(shape=img_shape)
    # downsampling 过程
    d1 = conv2d(d0,64,bn=False)
    d2 = conv2d(d1,64*2)
    d3 = conv2d(d2,64*4)
    d4 = conv2d(d3,64*8)
    d5 = conv2d(d4,64*8)
    d6 = conv2d(d5,64*8)
    d7 = conv2d(d6,64*8)
    # upsampling 过程
    u1 = deconv2d(d7,d6,64*8)
    u2 = deconv2d(u1,d5,64*8,dropout_rate=0.4)
    u3 = deconv2d(u2,d4,64*8)
    u4 = deconv2d(u3,d3,64*4,dropout_rate=0.4)
    u5 = deconv2d(u4,d2,64*2)
    u6 = deconv2d(u5,d1,64)
    # u6输出的shape是(128,128,64),所以为了输出与输入shape一致的结果
    # 即fake_A,需要进行上采样以及卷积操作,即u7和output_img
    u7 = UpSampling2D(size=2)(u6)
    output_img = Conv2D(3,kernel_size=4,strides=1,padding='same',activation='tanh')(u7)
    return Model(inputs=d0,outputs=[output_img])

if __name__ == '__main__':
    G = generator((256,256,3))
    G.summary()

参考上面的代码,在downsampling阶段,一般过程都是卷积、Leaky-ReLU、BatchNormalization,在upsampling阶段一般过程是先UpSamping,然后卷积,其次在某些层引入dropout噪声,之后进行BatchNormalization,最后通过skip connections进行concatenate(在channel维度进行,所以channel会增加)。下面是Generator模型每一层的一些特征,可以更好的理解模型的工作原理。

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 64) 3136        input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 64, 64, 128)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64, 64, 128)  512         leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 32, 32, 256)  524544      batch_normalization_1[0][0]      
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 32, 32, 256)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 32, 32, 256)  1024        leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 16, 16, 512)  2097664     batch_normalization_2[0][0]      
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 16, 16, 512)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 16, 16, 512)  2048        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 8, 8, 512)    4194816     batch_normalization_3[0][0]      
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 8, 8, 512)    0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 8, 8, 512)    2048        leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 4, 4, 512)    4194816     batch_normalization_4[0][0]      
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 4, 4, 512)    0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 4, 4, 512)    2048        leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 2, 2, 512)    4194816     batch_normalization_5[0][0]      
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 2, 2, 512)    0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 2, 2, 512)    2048        leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 4, 4, 512)    0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 4, 4, 512)    4194816     up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 4, 4, 512)    2048        conv2d_8[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 4, 4, 1024)   0           batch_normalization_7[0][0]      
                                                                 batch_normalization_5[0][0]      
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 8, 8, 1024)   0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 512)    8389120     up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 8, 8, 512)    2048        conv2d_9[0][0]                   
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 8, 8, 1024)   0           batch_normalization_8[0][0]      
                                                                 batch_normalization_4[0][0]      
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 16, 16, 1024) 0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 512)  8389120     up_sampling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 16, 16, 512)  2048        conv2d_10[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_9[0][0]      
                                                                 batch_normalization_3[0][0]      
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 32, 32, 1024) 0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  4194560     up_sampling2d_4[0][0]            
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 32, 32, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_10[0][0]     
                                                                 batch_normalization_2[0][0]      
__________________________________________________________________________________________________
up_sampling2d_5 (UpSampling2D)  (None, 64, 64, 512)  0           concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 64, 64, 128)  1048704     up_sampling2d_5[0][0]            
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_11[0][0]     
                                                                 batch_normalization_1[0][0]      
__________________________________________________________________________________________________
up_sampling2d_6 (UpSampling2D)  (None, 128, 128, 256 0           concatenate_5[0][0]              
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 128, 128, 64) 262208      up_sampling2d_6[0][0]            
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 128, 128, 64) 256         conv2d_13[0][0]                  
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_12[0][0]     
                                                                 leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
up_sampling2d_7 (UpSampling2D)  (None, 256, 256, 128 0           concatenate_6[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 256, 256, 3)  6147        up_sampling2d_7[0][0]            
==================================================================================================
Total params: 41,843,331
Trainable params: 41,834,499
Non-trainable params: 8,832
__________________________________________________________________________________________________

同样的对于Discriminator,定义一个discriminator函数,这个函数的输入是[img_A/fake_A,img_B],即真实的[图像-condition]或者‘Generator生成的图像-condition’,输出是一个256/2**4的矩阵,即16*16的矩阵,这里假设输入图像的shape是(256,256,3),2**4是discriminator下采样的倍数。换一种说法解释这个过程,在经过层层卷积之后的图像中的每一个像素点对应输入图像中16*16的区域,而这个点的属性fake或者real则对应输入图像中16*16区域的属性fake或者real,体现了patchGAN的概念。结合代码更好理解:

import keras
from keras.layers import Input,Concatenate,Conv2D
from keras.models import Model
img_A = Input(shape=(256,256,3))
img_B = Input(shape=(256,256,3))
combined = Concatenate(axis=-1)([img_A,img_B])
def d_layer(layer_input,filters,f_size=4,bn=True):
    d = Conv2D(filters,kernel_size=f_size,strides=2,padding='same')(layer_input)
    return d 
if __name__ == '__main__':
    d1 = d_layer(combined,64,bn=False)
    d2 = d_layer(d1,64*2)
    d3 = d_layer(d2,64*4)
    d4 = d_layer(d3,64*8)
    valid = Conv2D(1,kernel_size=4,strides=1,padding='same')(d4)
    model = Model(inputs=[img_A,img_B],outputs=[valid])
    model.summary()

当卷积中的padding设置为‘same'时,卷积后输出shape就等于原来的shape除以strides,用公式表示:

Nn = N / strides

代码中的valid即是PatchGAN最后输出的fake/real矩阵,shape是(16,16,1)下面是每一层的输入输出情况:

Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_13 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
input_14 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 256, 256, 6)  0           input_13[0][0]                   
                                                                 input_14[0][0]                   
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 128, 128, 64) 6208        concatenate_7[0][0]              
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 64, 64, 128)  131200      conv2d_26[0][0]                  
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 32, 32, 256)  524544      conv2d_27[0][0]                  
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 16, 16, 512)  2097664     conv2d_28[0][0]                  
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 16, 16, 1)    8193        conv2d_29[0][0]                  
==================================================================================================
Total params: 2,767,809
Trainable params: 2,767,809
Non-trainable params: 0
__________________________________________________________________________________________________

前文有提到作者在论文中提及的训练技巧,对于Generator需要和Discriminator进行一个combine进行训练,这其实就是一个函数调用的问题,直接看代码:

# img_A 与condition img_B 对应的真正图像
#img_B condition

img_A = Input(shape=(256,256,3))
img_B = Input(shape=(256,256,3))
fake_A = generator(img_B)
discriminator.trainable = False # 冻结discriminator
valid = discriminator([fake_A,img_B])
combined = Model(inputs=[img_A,img_B],outputs=[valid,fake_A])
combined.compile(loss=['mse','mae'],loss_weights=[1,100],optimizer=optimizer)

由于combined是由两个函数连接起来的,所以它的输入是两个函数的输入,输出是两个函数的输出,Discriminator的输出计算loss时使用’mse‘,'mae'是前面提到过的L1 loss,由于连接模型主要是为了训练generator,所以给L1 loss的权重设置位一个较大的数,这里是100。

关于batchsize的设置,论文中推荐是1~10,这个根据硬件条件进行调整,过大会爆显存,过小训练时间会长。另外,如果使用的是train_on_batch 模式,输入(X,Y),X是[img_A,img_B]来源于真实训练样本,Y是[valid],这个是一个矩阵,值全为0或全为1,训练discriminator时,如果输入是[img_A,img_B],则valid是一个全1矩阵,shape是(batchsize,16,16,1),如果输入时[fake_A,img_B],则valid是一个全0矩阵,shape是(batchsize,16,16,1),而通过combined训练generator时,输入是[img_A,img_B],Y是[valid,img_A]此时valid是(batchsize,16,16,1)的全1矩阵。

关于epoch的设置,博主在实验时发现20个epochs之后,模型趋于稳定。

图四 结果展示

如果有不清楚的或者有疑问的可以私信或者在评论区讨论。完整的代码暂不开放。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

nobrody

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值