Keras搭建CycleGAN

Keras搭建CycleGAN

1. 原理

参考:CycleGAN原理

2. 数据准备

2.1 数据下载

  1. 斑马to黄种马的数据集下载:
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
  2. 苹果to橘子数据集下载:
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
  3. 画作to照片数据集下载:
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/monet2photo.zip
  4. 地图数据集下载:
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/maps.zip

2.2 set_session设置

  • 由于set_session在最新版本中已经不存在,所以需要在头文件中添加
    import tensorflow as tf
    from tensorflow.python.keras import backend as K
    sess = tf.compat.v1.Session()
    K.set_session(sess)
    
  • 在初始设置中需要添加:
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
    tf.config.experimental.set_virtual_device_configuration(gpus[0],
         [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=6000)])  # 数值根据显卡内存设定
    

2.3 Tensorflow/Keras 指定CPU运行

2.3.1 全局配置

运行TensorFlow代码时候常出现OOM(Out of Memory)的错误,原因是batch_size设置得太大导致显存不足。如果想让代码仅仅运行在CPU下,可在原代码中加入如下代码:

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

注:上述代码一定要放在import tensorflow或keras等之前,否则不起作用。

2.3.2 tf配置

通过tensorflow配置指定到cpu上运行

with tf.device('/cpu:0'):
	xxx

或者

config = tf.ConfigProto(device_count = {'CPU': 4}) # 分配cpu个数
with tf.Session(config=config) as sess:
	xxx

2.4 keras_contrib库的Windows安装

参考:好像还挺好玩的GAN7——CycleGAN实现图像风格转换

3. 网络构建

3.1. Generator

  • 生成网络的目标是:输入一张图片,转化成自己期望的风格的那张图片。
  • 生成器由三部分组成: 编 码 器 \color{red}编码器 转 换 器 \color{red}转换器 解 码 器 \color{red}解码器 。(也可以用U-net网络)
  • 建立一个build_generator.py文件
# author: HQR
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import layers
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization


def residual_block(input_layer, kernel_size, filter_num, block):
    # 残差网络的函数
    con_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'
    # 第一层
    x1 = ZeroPadding2D(padding=(1, 1))(input_layer)
    x1 = Conv2D(filters=filter_num, kernel_size=kernel_size, name=con_name_base + '2a')(x1)
    x1 = InstanceNormalization(axis=3, name=in_name_base + '2a')(x1)
    # 第二层
    x2 = ZeroPadding2D(padding=(1, 1))(x1)
    x2 = Conv2D(filters=filter_num, kernel_size=kernel_size, name=con_name_base + '2c')(x2)
    x2 = InstanceNormalization(axis=3, name=in_name_base + '2c')(x2)
    # 残差
    x = layers.add([x2, input_layer])
    x = Activation('relu')(x)
    return x


def encoded(layer_input, filters, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=False):
    if upsampling2d:
        layer_input = UpSampling2D((2, 2))(layer_input)

    x = ZeroPadding2D(padding=pad_size)(layer_input)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
    return x


def build_generator(input_height, input_width, channel):

    img_input = Input(shape=(input_height, input_width, channel))
    # 第一步:编码
    # 128,128,3 ->  128,128,64
    g1 = encoded(img_input, filters=64, pad_size=(3, 3), kernel_size=(7, 7), strides=1)
    # 128,128,64 -> 64,64,128
    g1 = encoded(g1, filters=128, pad_size=(1, 1), kernel_size=(3, 3), strides=2)
    # 64,64,128 -> 32,32,256
    g1 = encoded(g1, filters=256, pad_size=(1, 1), kernel_size=(3, 3), strides=2)

    # 第二步: 转换器,残差网络
    for i in range(9):
        g1 = residual_block(g1, kernel_size=(3, 3), filter_num=256, block=str(i))

    # 第三步: 解码器
    # 32,32,256 -> 64,64,128
    g3 = encoded(g1, filters=128, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=True)
    # 64,64,128 -> 64,64,128 -> 128,128,64
    g3 = encoded(g3, filters=64, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=True)
    # 128,128,64 -> 128,128,3
    g3 = ZeroPadding2D(padding=(3, 3))(g3)
    img_output = Conv2D(channel, kernel_size=(7, 7), activation='tanh')(g3)

    return Model(img_input, img_output)

3.2 Discriminator

  • 判别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。
  • 判别器本身就属于卷积网络,需要从图像中提取特征;然后是确定这些特征是否属于该特定类别,使用一个产生一维输出的卷积层来完成这个任务。
  • Dicriminator的训练的loss函数使用的是LSGAN中所提到 均 方 差 \color{red}均方差 ,这种loss可以提高假图像的精度。
  • 最后卷积完后的shape为(8,8,1),利用了patch_GAN
  • 建立一个build_discriminator.py文件
# author: HQR
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization


def build_discriminator(input_height, input_width, channel):

    def conv2d(layer_input, filters, f_size=4, nomalization=True):
        d = Conv2D(filters=filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if nomalization:
            d = InstanceNormalization()(d)
        return d

    img_input = Input(shape=(input_height, input_width, channel))
    # 128,128,3 -> 64,64,64
    d1 = conv2d(img_input, 64, nomalization=False)
    # 64,64,64 -> 32,32,128
    d2 = conv2d(d1, 128)
    # 32,32,128 -> 16,16,256
    d3 = conv2d(d2, 256)
    # 16,16,256 -> 8,8,512
    d4 = conv2d(d3, 512)
    # 对每个像素点判断是否有效
    # 8,8,512 -> 8,8,1
    validity = Conv2D(filters=1, kernel_size=3, strides=1, padding="same")(d4)

    return Model(img_input, validity)

3.3 数据加载

  • 由于是要用2回load_batch,所以此处不用return,而使用yield
  • 建立一个data_loader.py文件
# author:HQR
import imageio
from skimage.transform import resize
from glob import glob
import numpy as np


class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_image = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_image:
            img = self.imread(img_path)
            if not is_testing:
                img = resize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.
        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))
        # 选择batch
        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size
        # 选择数据
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)
        # 选择batch数据
        for i in range(self.n_batches - 1):
            batch_A = path_A[i*batch_size: (i+1)*batch_size]
            batch_B = path_B[i*batch_size: (i+1)*batch_size]
            imgs_A, imgs_B = [], []
            # zip打包成元组处理
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = resize(img_A, self.img_res)
                img_B = resize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = resize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(np.float)

3.4 训练

  1. 初始化
    1. 创建两个生成模型,一个用于从图片风格A转换成图片风格B,一个用于从图片风格B转换成图片风格A。
    2. 创建两个判别模型,分别用于风格A图片的真伪判断和风格B图片的真伪判断。
    3. 判别模型的训练所用的损失函数与LSGAN相同,通过判断是否正确进行训练。
  2. 损失设定
    损失有以下6种:参考:CycleGAN原理
from __future__ import print_function, division
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from build_generator import *
from bulid_discriminator import *
from data_loader import *
from tensorflow.python.keras import backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import datetime
import os
# 由于GPU总是爆显存,关闭GPU,用CPU进行操作
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# 设置set_session,与GPU有关
sess = tf.compat.v1.Session()
K.set_session(sess)

class CycleGAN():
    def __init__(self):
        # 设置GPU,防止内存爆
        # config = tf.compat.v1.ConfigProto()
        # config.gpu_options.allocator_type = 'BFC'  # A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc.
        # config.gpu_options.per_process_gpu_memory_fraction = 0.8
        # config.gpu_options.allow_growth = True
        # K.set_session(tf.compat.v1.Session(config=config))
        # 输入图像大小128*128*3
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # 载入数据
        self.dataset_name = 'horse2zebra'
        self.data_loader = DataLoader(dataset_name= self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))
        # Calculate output shape of D (PatchGAN)
        # 因为Discriminator 引用了 PatchGAN 的思想
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # 设置参数
        # Loss weights
        self.lambda_cycle = 10.0  # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle  # Identity loss
        # 优化器参数
        optimizer = Adam(0.0002, 0.5)

        # -------------------------#
        #   建立判别网络
        # -------------------------#
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.d_A.summary()

        # -------------------------#
        #   建立判别网络
        # -------------------------#
        # 创建生成模型
        self.g_A2B = self.build_generator()
        self.g_B2A = self.build_generator()
        self.g_A2B.summary()

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # 生成假图片
        fake_B = self.g_A2B(img_A)
        fake_A = self.g_B2A(img_B)
        # 生成重建图片(reconstruction image)
        recon_A = self.g_B2A(fake_B)
        recon_B = self.g_A2B(fake_A)
        # 生成identity图片
        id_A = self.g_B2A(img_A)
        id_B = self.g_A2B(img_B)

        # -------------------------#
        #   将生成模型和判别模型结合,生成模型训练时候,训练时候不训练判别模型
        # -------------------------#
        self.d_A.trainable = False
        self.d_B.trainable = False
        # 评价是否为真
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)
        # 训练
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       recon_A, recon_B,
                                       id_A, id_B])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer=optimizer)
        self.combined.summary()

    def build_generator(self):
        model = build_generator(self.img_rows, self.img_cols, self.channels)
        return model

    def build_discriminator(self):
        model = build_discriminator(self.img_rows, self.img_cols, self.channels)
        return model

    def scheduler(self, models, epoch):
        # 每隔100个epoch,学习率减小为原来的1/2
        if epoch % 20 == 0 and epoch != 0:
            for model in models:
                lr = K.get_value(model.optimizer.lr)
                K.set_value(model.optimizer.lr, lr * 0.5)
            print("lr changed to {}".format(lr * 0.5))

    def train(self, init_epoch, epochs, batch_size=1, sample_interval=50):
        start_time = datetime.datetime.now()

        # 用到patchGAN
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        if init_epoch != 0:
            self.d_A.load_weights("weight/%s/d_A_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)
            self.d_B.load_weights("weight/%s/d_B_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)
            self.g_A2B.load_weights("weight/%s/g_A2B_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)
            self.g_B2A.load_weights("weight/%s/g_B2A_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)

        for epoch in range(init_epoch, epochs):
            self.scheduler([self.combined, self.d_A, self.d_B], epoch)
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # ------------------ #
                #  训练生成模型
                # ------------------ #
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])
                # ---------------------- #
                #  训练判别模型
                # ---------------------- #
                # 1. 假图片的Loss_D
                fake_B = self.g_A2B.predict(imgs_A)
                fake_A = self.g_B2A.predict(imgs_B)
                # 1.1 对于判别器A
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                # 1.2 对于判别器B
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
                # 1.3 Loss_Discriminator
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                elapsed_time = datetime.datetime.now() - start_time
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, "
                      "id: %05f] time: %s " % (epoch, epochs, batch_i, self.data_loader.n_batches,
                                               d_loss[0], 100 * d_loss[1], g_loss[0], np.mean(g_loss[1:3]),
                                               np.mean(g_loss[3:5]), np.mean(g_loss[5:6]), elapsed_time))

                # 保存训练模型
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
                    if epoch % 5 == 0 and epoch != init_epoch:
                        os.makedirs('weight/%s' % self.dataset_name, exist_ok=True)
                        self.d_A.save_weights("weight/%s/d_A_epoch%d.h5" % (self.dataset_name, epoch))
                        self.d_B.save_weights("weight/%s/d_B_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_A2B.save_weights("weight/%s/d_A2B_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_B2A.save_weights("weight/%s/d_B2A_epoch%d.h5" % (self.dataset_name, epoch))


    def sample_images(self, epoch, batch_i):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)
        # Translate images to the other domain
        fake_B = self.g_A2B.predict(imgs_A)
        fake_A = self.g_B2A.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_B2A.predict(fake_B)
        reconstr_B = self.g_A2B.predict(fake_A)
        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()

if __name__ =='__main__':
    gan = CycleGAN()
    gan.train(init_epoch=0, epochs=200, batch_size=1, sample_interval=200)

3.5 测试

  • 建立一个predict.py文件
from build_generator import *
from PIL import Image
import numpy as np
model = build_generator(None,None,3)
model.load_weights(r"weights\horse2zebra\g_B2A_epoch15.h5")# 根据自己训练的
# 图片根据自己需求选取
img = np.array(Image.open(r"datasets\horse2zebra\trainB\n02391049_32.jpg").resize([256,256]))/127.5 - 1
img = np.expand_dims(img,axis=0)
fake = (model.predict(img)*0.5 + 0.5)*255

face = Image.fromarray(np.uint8(fake[0]))
face.show()

参考

  1. 好像还挺好玩的GAN7——CycleGAN实现图像风格转换
  2. BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结
  3. Batch normalization和Instance normalization的对比?
  4. PatchGAN理解
  5. keras Conv2D参数详解
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值