基于Mxnet实现GAN-CycleGAN【附部分源码】


前言

本文基于Mxnet实现CycleGAN


一、CycleGAN是什么

CycleGAN是一种图像翻译模型,由两个生成网络和两个判别网络组成。它利用非成对(unpaired)的训练图片,学习将某一类图片转换成另外一类图片,可用于风格迁移等任务。

  • 生成式对抗网络(GAN, Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
  • 机器学习的模型可大体分为两类,生成模型(Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 $x$,通过某种模型来预测 $p(y \mid x)$。生成模型是给定某种隐含信息,来随机产生观测数据。举个简单的例子:
  • 生成模型:给一系列猫的图片,生成一张新的猫咪(不在数据集里)
  • 判别模型:给定一张图,判断这张图里的动物是猫还是狗

二、代码实现

1.引入库

import random, os, cv2, time
import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon, image, autograd
from mxnet.gluon.data.vision import transforms
from mxnet.base import numeric_types
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from mxboard import SummaryWriter

2.网络构建

def define_G(output_nc, ngf, which_model_netG, use_dropout=False):
    """Build a CycleGAN generator selected by architecture name.

    Args:
        output_nc: First positional argument forwarded to the generator
            constructor (presumably the number of output channels -- confirm
            against ResnetGenerator / UnetGenerator).
        ngf: Base filter count forwarded to the generator constructor.
        which_model_netG: One of 'resnet_9blocks', 'resnet_6blocks',
            'unet_128', 'unet_256'.
        use_dropout: Forwarded unchanged to the generator constructor.

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: If ``which_model_netG`` is not a recognised name.
    """
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=9)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=6)
    elif which_model_netG == 'unet_128':
        # Second positional argument (7 / 8) is presumably the number of
        # down-sampling stages of the U-Net -- confirm against UnetGenerator.
        netG = UnetGenerator(output_nc, 7, ngf, use_dropout=use_dropout)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(output_nc, 8, ngf, use_dropout=use_dropout)
    else:
        # Report the name the caller actually passed; no `opt` object is in
        # scope here, so referencing it would raise NameError instead.
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    return netG

def define_D(ndf, which_model_netD, n_layers_D=3, use_sigmoid=False):
    """Build a PatchGAN-style N-layer discriminator selected by name.

    'basic' always builds a 3-layer discriminator and ignores ``n_layers_D``;
    'n_layers' uses ``n_layers_D``.

    Raises:
        NotImplementedError: If ``which_model_netD`` is not a recognised name.
    """
    if which_model_netD not in ('basic', 'n_layers'):
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)

    if which_model_netD == 'basic':
        return NLayerDiscriminator(ndf, n_layers=3, use_sigmoid=use_sigmoid)
    return NLayerDiscriminator(ndf, n_layers_D, use_sigmoid=use_sigmoid)

3.数据加载器

class DataSet(gluon.data.Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    Every entry of ``DataDir_A`` and ``DataDir_B`` is listed (sorted by full
    path). Index ``i`` yields one image from each domain; each domain wraps
    around independently via modulo, and the dataset length is the size of
    the larger domain.
    """

    def __init__(self,DataDir_A, DataDir_B, transform):
        # Sorted full paths of every directory entry, per domain.
        self.A_paths = sorted(os.path.join(DataDir_A, entry) for entry in os.listdir(DataDir_A))
        self.B_paths = sorted(os.path.join(DataDir_B, entry) for entry in os.listdir(DataDir_B))
        self.A_size, self.B_size = len(self.A_paths), len(self.B_paths)
        # Applied separately to the A image and the B image.
        self.transform = transform

    def __getitem__(self, index):
        """Return the transformed ``(A, B)`` image pair for ``index``."""
        img_a = image.imread(self.A_paths[index % self.A_size])
        img_b = image.imread(self.B_paths[index % self.B_size])
        return self.transform(img_a), self.transform(img_b)

    def __len__(self):
        """Number of items: the size of the larger of the two domains."""
        return max(self.A_size, self.B_size)

4.模型训练

1.优化器设置

# One Adam trainer per network (shared learning rate, beta1 = 0.5). Each
# trainer only updates the parameters of the network it was built from,
# in the order G_A, G_B, D_A, D_B.
optimizer_GA, optimizer_GB, optimizer_DA, optimizer_DB = (
    gluon.Trainer(net.collect_params(), 'adam',
                  {'learning_rate': learning_rate, 'beta1': 0.5},
                  kvstore='local')
    for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B)
)

2.损失函数定义

# L1 loss shared by the cycle-consistency terms (reconstruction vs. real
# image) and the identity terms (identity mapping vs. real image) of the
# generator objective in the training loop.
cyc_loss = gluon.loss.L1Loss()

3.循环训练

# One CycleGAN iteration per batch: first update both generators, then both
# discriminators. netG_A maps A -> B, netG_B maps B -> A.
for i, (real_A, real_B) in enumerate(self.data_loader):
    # Split each batch along axis 0 into one slice per context in self.ctx.
    real_A = gluon.utils.split_and_load(real_A, ctx_list=self.ctx, batch_axis=0)
    real_B = gluon.utils.split_and_load(real_B, ctx_list=self.ctx, batch_axis=0)
    loss_G_list = []
    loss_D_A_list = []
    loss_D_B_list = []
    # Fakes are kept so the discriminator step can reuse them without
    # re-running the generators.
    fake_A_list = []
    fake_B_list = []
    losses_log.reset()
    # ---- Generator step ----
    with autograd.record():
        for A,B in zip(real_A,real_B):
            # Translate to the other domain and back (cycle reconstruction).
            fake_B = self.netG_A(A)
            rec_A = self.netG_B(fake_B)
            fake_A = self.netG_B(B)
            rec_B = self.netG_A(fake_A)

            # Identity terms: a generator fed an image already in its output
            # domain is pushed to return it unchanged (weight 10.0 * 0.5).
            idt_A = self.netG_A(B)
            loss_idt_A = cyc_loss(idt_A,B) * 10.0 * 0.5
            idt_B = self.netG_B(A)
            loss_idt_B = cyc_loss(idt_B,A) * 10.0 * 0.5

            # Adversarial terms use target=True: the generators try to make the
            # discriminators accept their fakes. Cycle terms are weighted 10.0.
            loss_G_A = self.gan_loss(self.netD_A(fake_B),True)
            loss_G_B = self.gan_loss(self.netD_B(fake_A),True)
            loss_cycle_A = cyc_loss(rec_A,A) * 10.0
            loss_cycle_B = cyc_loss(rec_B,B) * 10.0
            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B

            loss_G_list.append(loss_G)
            fake_A_list.append(fake_A)
            fake_B_list.append(fake_B)
            losses_log.add(loss_G_A=loss_G_A, loss_cycle_A=loss_cycle_A, loss_idt_A=loss_idt_A,loss_G_B=loss_G_B,
                        loss_cycle_B=loss_cycle_B, loss_idt_B=loss_idt_B,real_A=A, fake_B=fake_B, rec_A=rec_A,
                        idt_A=idt_A, real_B=B, fake_A=fake_A, rec_B=rec_B,idt_B=idt_B)
        # Backward over the losses from every device slice at once.
        autograd.backward(loss_G_list)
    # Only the generator trainers step here. Trainer.step(batch_size) scales
    # the accumulated gradients by 1 / batch_size.
    optimizer_GA.step(self.batch_size)
    optimizer_GB.step(self.batch_size)
    # ---- Discriminator step ----
    with autograd.record():
        for A,B,fake_A,fake_B in zip(real_A,real_B,fake_A_list,fake_B_list):
            # netD_A separates real B images from generated B images.
            # NOTE(review): the pool's query() presumably returns either this
            # fake or an earlier stored one (image history buffer) -- confirm
            # against the pool class.
            fake_B_tmp = fake_B_pool.query(fake_B)
            pred_real = self.netD_A(B)
            loss_D_real = self.gan_loss(pred_real,True)
            # detach() stops this loss from back-propagating into netG_A.
            pred_fake = self.netD_A(fake_B_tmp.detach())
            loss_D_fake = self.gan_loss(pred_fake, False)
            # Halve the D objective (average of real and fake terms).
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A)

            # netD_B separates real A images from generated A images.
            fake_A_tmp = fake_A_pool.query(fake_A)
            pred_real = self.netD_B(A)
            loss_D_real = self.gan_loss(pred_real, True)
            pred_fake = self.netD_B(fake_A_tmp.detach())
            loss_D_fake = self.gan_loss(pred_fake,False)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B)
            losses_log.add(loss_D_A=loss_D_A,loss_D_B=loss_D_B)
        autograd.backward(loss_D_A_list + loss_D_B_list)
    optimizer_DA.step(self.batch_size)
    optimizer_DB.step(self.batch_size)
    # NOTE: `% 1 == 0` is always true, so losses and images are logged on every
    # iteration whenever self.sw is set. The first argument to plot_loss is a
    # global step counter (epoch assumed to start at 1).
    if ((epoch-1) * len(self.data_loader) + i) % 1 == 0 and self.sw is not None:
        plot_loss(losses_log, (epoch-1) * len(self.data_loader) + i,epoch,i, self.sw)
        plot_img(losses_log, self.sw)

4.模型保存

# Persist the weights of all four networks as <ModelPath>/<attribute>.dat,
# in the order G_A, G_B, D_A, D_B.
for net_name in ('netG_A', 'netG_B', 'netD_A', 'netD_B'):
    getattr(self, net_name).save_parameters(os.path.join(ModelPath, net_name + '.dat'))

5.模型预测

 def predict(self,cv_img,ATOB=True):
    """Translate one OpenCV image with a trained generator.

    Args:
        cv_img: Image as returned by ``cv2.imread`` (BGR channel order).
        ATOB: If True use ``netG_A`` (A -> B), otherwise ``netG_B`` (B -> A).

    Returns:
        dict with ``"image_result"`` (uint8 HWC image in BGR order, ready for
        ``cv2.imshow``) and ``"time"`` (elapsed milliseconds, measured from
        after the initial colour conversion to the end of post-processing).
    """
    img_origin = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    start_time = time.time()
    img = nd.array(img_origin)
    img = self.transform_fn(img)
    # The training loop uses self.ctx as a ctx_list (a list of contexts);
    # accept either a single context or a list, and run on the first one.
    ctx = self.ctx[0] if isinstance(self.ctx, (list, tuple)) else self.ctx
    img = img.expand_dims(0).as_in_context(ctx)
    # Plain forward pass. autograd.record() is not used here: it defaults to
    # train_mode=True, which would run dropout/batch-norm in training mode and
    # build an unneeded gradient graph during inference.
    generator = self.netG_A if ATOB else self.netG_B
    output = generator(img)
    predict = mx.nd.squeeze(output)
    # CHW -> HWC, then undo the (presumed) [-1, 1] normalisation into uint8.
    predict = ((predict.transpose([1,2,0]).asnumpy() * 0.5 + 0.5) * 255).clip(0, 255).astype('uint8')
    # The network works in RGB (input was converted above); go back to BGR
    # for OpenCV.
    res_image = cv2.cvtColor(predict, cv2.COLOR_RGB2BGR)
    result_value = {
        "image_result": res_image,
        "time": (time.time() - start_time) * 1000
    }
    return result_value

三、函数主入口

本人的代码调用比较简单

if __name__ == '__main__':
    data_root = 'D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite'

    # --- Training on the unpaired trainA / trainB folders ---
    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.InitModel(DataDir_A=data_root + '/trainA',
                  DataDir_B=data_root + '/trainB',
                  channels=3, batch_size=1, num_workers=0, channels_rate=0.5)
    ctu.train(TrainNum=300, learning_rate=0.0001, lr_decay_epoch='50,100,150,200',
              lr_decay=0.9, ModelPath='./Model', logDir='./logs')

    # --- Inference: reload the saved weights and show A -> B results ---
    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.LoadModel(ModelPath=['./Model/%s.dat' % net_name
                             for net_name in ('netG_A', 'netG_B', 'netD_A', 'netD_B')])
    for window in ("origin", "result"):
        cv2.namedWindow(window, 0)
        cv2.resizeWindow(window, 640, 480)

    for root, _dirs, files in os.walk(data_root + '/testA'):
        for f in files:
            img_cv = cv2.imread(os.path.join(root, f))
            # Skip files OpenCV cannot decode.
            if img_cv is not None:
                res = ctu.predict(img_cv, ATOB=True)
                print("耗时:" + str(res['time']) + ' ms')
                cv2.imshow("origin", img_cv)
                cv2.imshow("result", res['image_result'])
                cv2.waitKey()

四、训练效果展示

(此处原为训练效果示意图,图片在转载时未能保留。)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱学习的广东仔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值