CycleGAN Training Tutorial

Table of Contents

1. Project Download Links

2. CycleGAN Overview

3. How CycleGAN Works

4. CycleGAN Application Scenarios

5. Training Procedure

       (1) Code Contents

       (2) Environment Setup

       (3) Downloading Pretrained Weights

       (4) Downloading Training Data

       (5) Parameter Settings

       (6) Running Training

       (7) Training UI

       (8) Training Results

6. Testing

       (1) Test Command

       (2) Test Results

7. Related Links



1. Project Download Links

        A PyTorch image-generation codebase that contains both the CycleGAN and pix2pix models, well suited to image generation and style-transfer tasks.

Paper: https://arxiv.org/pdf/1703.10593.pdf

Code: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix (GitHub); mirror: https://gitcode.com/junyanz/pytorch-CycleGAN-and-pix2pix/overview (GitCode)

2. CycleGAN Overview

        Image-to-image translation methods typically rely on a training set of aligned image pairs between the input and output domains. For many tasks, however, paired training data is unavailable. CycleGAN is an approach for learning to translate an image from a source domain X to a target domain Y in the absence of paired examples. The goal is to learn a mapping G : X → Y such that the distribution of images from G(X) is indistinguishable from the distribution Y, using an adversarial loss. Because this mapping is highly under-constrained, it is coupled with an inverse mapping F : Y → X, and a cycle-consistency loss is introduced to enforce F(G(X)) ≈ X (and vice versa). The paper presents qualitative results on several tasks where paired training data does not exist, including collection style transfer, object transfiguration, season transfer, and photo enhancement, along with quantitative comparisons.
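
        In symbols, the adversarial part of the objective for the mapping G and its discriminator D_Y is the standard GAN loss from the paper:

\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log(1 - D_Y(G(x)))]

        with a symmetric term \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) for the inverse mapping. (In the released code the log terms are replaced by least-squares losses: the lsgan mode visible in the training options below.)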

3. How CycleGAN Works

(a) The model contains two mapping functions G : X → Y and F : Y → X, along with the associated adversarial discriminators D_Y and D_X. D_Y encourages G to translate X into outputs indistinguishable from domain Y, and vice versa for D_X and F. To further regularize the mappings, two cycle-consistency losses capture the intuition that translating from one domain to the other and back again should return us to where we started: (b) the forward cycle-consistency loss, x → G(x) → F(G(x)) ≈ x, and (c) the backward cycle-consistency loss, y → F(y) → G(F(y)) ≈ y.
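
Written out, the cycle-consistency loss and the full objective from the paper are:

\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\, \| F(G(x)) - x \|_1 \,] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}[\, \| G(F(y)) - y \|_1 \,]

\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda \, \mathcal{L}_{\mathrm{cyc}}(G, F)

where the paper sets \lambda = 10.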

4. CycleGAN Application Scenarios
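
        Per the paper, typical applications include collection style transfer (e.g., photographs rendered in the style of Monet or Van Gogh), object transfiguration (horse ↔ zebra, apple ↔ orange), season transfer (summer ↔ winter), and photo enhancement.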

5. Training Procedure

        (Using the horse ↔ zebra translation task as the running example.)

       (1) Code Contents

       (2) Environment Setup

# Command:
pip install -r requirements.txt

         The contents of requirements.txt:

torch>=1.4.0
torchvision>=0.5.0
dominate>=2.4.0
visdom>=0.1.8.8
wandb

        For torch and torchvision, it is best to follow the install selector on the official PyTorch site: https://pytorch.org
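
        After installation, a quick sanity check (a throwaway snippet, not part of the repository) confirms the key packages import and that the GPU is visible:

import dominate  # noqa: F401  (only checking that it imports)
import visdom    # noqa: F401
import torch
import torchvision

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())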

        (3) Downloading Pretrained Weights

        From the repository root pytorch-CycleGAN-and-pix2pix-master/, run:

bash ./scripts/download_cyclegan_model.sh horse2zebra

        The weights are saved to ./checkpoints/horse2zebra_pretrained/.

        Alternatively, download the .pth directly in a browser from the Berkeley index page (Index of /cyclegan/pretrained_models (berkeley.edu); the full URL appears inside the download script).
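
        If bash is unavailable (e.g., on plain Windows), the same file can be fetched in Python. The URL pattern below is the one used inside scripts/download_cyclegan_model.sh at the time of writing; verify it against your copy of the script:

import os
import urllib.request

model = "horse2zebra"
url = f"http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/{model}.pth"
dest_dir = f"./checkpoints/{model}_pretrained"
os.makedirs(dest_dir, exist_ok=True)
# The test code looks for generator weights named latest_net_G.pth,
# which is also how the shell script saves them.
urllib.request.urlretrieve(url, os.path.join(dest_dir, "latest_net_G.pth"))
print("saved to", dest_dir)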

       (4) Downloading Training Data

        Run:

bash ./datasets/download_cyclegan_dataset.sh horse2zebra

         The data is saved to ./datasets/horse2zebra/.

        Alternatively, download in a browser: open ./datasets/download_cyclegan_dataset.sh to see the dataset URL (Index of /cyclegan/datasets).
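
        Whichever route you take, the unaligned dataset mode expects trainA/trainB/testA/testB subfolders under the dataset root. A quick structural check (adjust root to your own path):

import os

root = "./datasets/horse2zebra"
for split in ("trainA", "trainB", "testA", "testB"):
    path = os.path.join(root, split)
    count = len(os.listdir(path)) if os.path.isdir(path) else 0
    print(f"{split}: {count} files")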

       (5) Parameter Settings

        All option definitions live under ./options/.

        Base options: base_options.py

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0,5,6,7', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser

        Training options: train_options.py


    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters  (total epochs = niter + niter_decay = 200)
        parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
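
        With the defaults above, training runs for niter + niter_decay = 200 epochs: the learning rate is held at lr for the first niter epochs, then decayed linearly to zero over the next niter_decay. A minimal sketch of that rule (it mirrors the lambda_rule used by the repository's 'linear' policy in models/networks.py; verify against your copy):

import torch

def make_linear_scheduler(optimizer, niter=100, niter_decay=100, epoch_count=1):
    # Constant LR for `niter` epochs, then linear decay to zero over `niter_decay` epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

# Example: the LR multiplier stays at 1.0 through epoch 99, then shrinks toward 0 by epoch 200.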

        Test options: test_options.py

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and BatchNorm have different behavior during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        # override default values
        parser.set_defaults(model='test')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
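
        Note that test_options.py forces model='test' by default, which loads only a single generator for one direction; the test command in section 6 overrides this with --model cycle_gan so both directions (G_A and G_B) are evaluated. For the downloaded single-generator horse2zebra_pretrained weights, the repository README instead uses --model test together with --no_dropout.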

        (6) Running Training

        Command:

python train.py --dataroot ./datasets/horse2zebra --name horse2zebra_cyclegan --model cycle_gan

        (7) Training UI
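
        The training UI is a visdom dashboard (see display_server/display_port in the options above; the default address is http://localhost:8097). Start the server in a separate terminal before launching training, then open that address in a browser:

python -m visdom.server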

        (8) Training Results

        Results are saved under ./checkpoints/horse2zebra_cyclegan/.
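
        With the defaults above, a rolling latest_net_*.pth set is kept and a numbered snapshot is written every save_epoch_freq=5 epochs; for the cycle_gan model this includes weights for both generators and both discriminators (net_G_A, net_G_B, net_D_A, net_D_B). The file naming here reflects the repository's convention; verify against your own run.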

        

6. Testing

       (1) Test Command

python test.py --dataroot ./datasets/horse2zebra --name horse2zebra_cyclegan --model cycle_gan
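
        Note that only num_test=50 images are processed by default (see test_options.py above); pass a larger --num_test to translate more of the test set.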

        

        (2) Test Results

         The test results can be found under ./results/horse2zebra_cyclegan/test_latest/.

         In the output, the real_* images are the originals, the fake_* images are the translations into the other domain's style, and the rec_* images are reconstructions of the original recovered from the fake.

7. Related Links

Training CycleGAN on your own dataset (用自己的数据集实战CycleGAN) - CSDN blog
