StarGAN v2:多领域的不同图像合成

在这里插入图片描述
在这里插入图片描述

前言

相关介绍

在这里插入图片描述

StarGAN v2 是一种高级的生成对抗网络 (GAN) 架构,专门设计用于多域图像合成与转换。它是在 StarGAN 的基础上发展起来的,旨在解决多样性和可扩展性的问题。以下是关于 StarGAN v2 的详细介绍及其优缺点:

StarGAN v2 的工作原理

核心思想
  • Domain 和 Style 分离:StarGAN v2 将域(domain)和风格(style)的概念分离,允许用户独立控制图像的内容和风格。
  • Mapping Network:引入了一个映射网络(mapping network),负责将随机编码映射到不同的伪风格码。
  • Style Encoder:使用风格编码器(style encoder)来从真实的图像中获取风格码,这样可以确保生成的图像既具有多样性又保持了原有的风格细节。
主要组件
  • Generator:负责生成图像,能够处理多个域的转换。
  • Discriminator:用于判断生成图像的真实度,帮助训练过程更加稳定。
  • Mapping Network:将随机噪声映射到风格空间。
  • Style Encoder:从真实图像中提取风格信息。

优点

  1. 多样性:能够生成具有高度多样性的图像,避免了生成结果过于单一的问题。
  2. 可扩展性:能够在一个模型中处理多个不同的域,无需为每个域单独训练模型。
  3. 统一的框架:提供了统一的框架来处理不同的任务,如人脸属性编辑、动物种类转换等。
  4. 灵活性:用户可以独立控制图像的内容和风格,增加了使用的灵活性。
  5. 高质量输出:在多个基准测试上展示了高质量的图像生成能力。

缺点

  1. 复杂性:引入了多个组件(如映射网络和风格编码器),使得模型结构更加复杂,可能增加了训练难度。
  2. 训练资源需求:复杂的架构可能需要更多的计算资源和更长的训练时间。
  3. 泛化能力限制:虽然在特定数据集上表现出色,但可能在其他数据集或领域上的泛化能力有待验证。
  4. 标签信息限制:在使用多个数据集进行训练时,每个数据集只包含部分标签信息,这可能会影响某些任务的表现。

应用实例

  • 人脸属性编辑:例如改变人脸的性别、年龄等属性。
  • 动物种类转换:例如将猫的图像转换为狗的图像。
  • 其他图像转换任务:如季节变换、绘画风格转移等。

总结

StarGAN v2 是一个多域图像转换的前沿技术,它在多样性和可扩展性方面取得了显著的进步,使其成为图像合成领域的重要工具。不过,它的复杂性也可能带来一些挑战,尤其是在实际部署和训练过程中。

实验环境

python=3.6.7
pytorch=1.4.0
torchvision=0.5.0
ffmpeg=4.0.2 
opencv-python==4.1.2.30 
ffmpeg-python==0.2.0 
scikit-image==0.16.2
pillow==7.0.0 
scipy==1.2.1 
tqdm==4.43.0 
munch==2.5.0

项目地址

Linux

git clone https://github.com/clovaai/stargan-v2.git
cd stargan-v2

Windows

请到https://github.com/clovaai/stargan-v2.git网站下载源代码zip压缩包。

cd stargan-v2

项目结构

stargan-v2
├─assets
│  └─representative
│      ├─afhq
│      │  ├─ref
│      │  │  ├─cat
│      │  │  ├─dog
│      │  │  └─wild
│      │  └─src
│      │      ├─cat
│      │      ├─dog
│      │      └─wild
│      ├─celeba_hq
│      │  ├─ref
│      │  │  ├─female
│      │  │  └─male
│      │  └─src
│      │      ├─female
│      │      └─male
│      └─custom
│          ├─female
│          └─male
├─core
│  └─__pycache__
├─data
│  └─celeba_hq
│      ├─train
│      │  ├─female
│      │  └─male
│      └─val
│          ├─female
│          └─male
├─expr
│  ├─checkpoints
│  │  └─celeba_hq
│  └─results
│      └─celeba_hq
└─metrics
    └─__pycache__

具体用法

数据集和预训练网络

  • 源码提供了一个脚本来下载StarGAN v2和相应的预训练网络中使用的数据集。数据集和网络检查点将被下载并分别存储在data和expr/检查点目录中。
  • CelebA-HQ。要下载CelebA-HQ数据集和预训练网络,可以运行以下命令:
bash download.sh celeba-hq-dataset
bash download.sh pretrained-network-celeba-hq
bash download.sh wing
  • AFHQ。要下载AFHQ数据集和预训练网络,可以运行以下命令:
bash download.sh afhq-dataset
bash download.sh pretrained-network-afhq

进行评估(以celeba-hq为例)

使用 Fréchet Inception Distance (FID) and Learned Perceptual Image Patch Similarity (LPIPS),进行评估StarGAN v2。

# celeba-hq
python main.py --mode eval --num_domains 2 --w_hpf 1 --resume_iter 100000 --train_img_dir data/celeba_hq/train --val_img_dir data/celeba_hq/val --checkpoint_dir expr/checkpoints/celeba_hq --eval_dir expr/eval/celeba_hq

# afhq
python main.py --mode eval --num_domains 3 --w_hpf 0 --resume_iter 100000 --train_img_dir data/afhq/train --val_img_dir data/afhq/val --checkpoint_dir expr/checkpoints/afhq --eval_dir expr/eval/afhq

在这里插入图片描述
在这里插入图片描述

  • 注意,评估指标是使用随机潜在向量或参考图像计算的,这两者都是由种子数选择的。

进行训练(以celeba-hq为例)

要从头开始训练StarGAN v2,运行以下命令。生成的映像和网络检查点将分别存储在expr/samples和expr/检查点目录中。在单个Tesla V100 GPU上训练大约需要三天。请参阅这里的训练参数和它们的描述。

# celeba-hq
python main.py --mode train --num_domains 2 --w_hpf 1 --lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 --train_img_dir data/celeba_hq/train --val_img_dir data/celeba_hq/val

# afhq
python main.py --mode train --num_domains 3 --w_hpf 0 --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 --train_img_dir data/afhq/train --val_img_dir data/afhq/val

在这里插入图片描述

进行测试(以celeba-hq为例)

新建一个test.py文件,内容如下。

import os
import argparse

from munch import Munch
from torch.backends import cudnn
import torch

from core.data_loader import get_train_loader
from core.data_loader import get_test_loader
from core.solver import Solver


def str2bool(v):
    return v.lower() in ('true')


def subdirs(dname):
    return [d for d in os.listdir(dname)
            if os.path.isdir(os.path.join(dname, d))]


def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model arguments
    parser.add_argument('--img_size', type=int, default=256,
                        help='Image resolution')
    parser.add_argument('--num_domains', type=int, default=2,
                        help='Number of domains')
    parser.add_argument('--latent_dim', type=int, default=16,
                        help='Latent vector dimension')
    parser.add_argument('--hidden_dim', type=int, default=512,
                        help='Hidden dimension of mapping network')
    parser.add_argument('--style_dim', type=int, default=64,
                        help='Style code dimension')

    # weight for objective functions
    parser.add_argument('--lambda_reg', type=float, default=1,
                        help='Weight for R1 regularization')
    parser.add_argument('--lambda_cyc', type=float, default=1,
                        help='Weight for cyclic consistency loss')
    parser.add_argument('--lambda_sty', type=float, default=1,
                        help='Weight for style reconstruction loss')
    parser.add_argument('--lambda_ds', type=float, default=1,
                        help='Weight for diversity sensitive loss')
    parser.add_argument('--ds_iter', type=int, default=100000,
                        help='Number of iterations to optimize diversity sensitive loss')
    parser.add_argument('--w_hpf', type=float, default=1,
                        help='weight for high-pass filtering')

    # training arguments
    parser.add_argument('--randcrop_prob', type=float, default=0.5,
                        help='Probabilty of using random-resized cropping')
    parser.add_argument('--total_iters', type=int, default=100000,
                        help='Number of total iterations')
    parser.add_argument('--resume_iter', type=int, default=0,
                        help='Iterations to resume training/testing')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')  
    parser.add_argument('--val_batch_size', type=int, default=32,
                        help='Batch size for validation')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate for D, E and G')
    parser.add_argument('--f_lr', type=float, default=1e-6,
                        help='Learning rate for F')
    parser.add_argument('--beta1', type=float, default=0.0,
                        help='Decay rate for 1st moment of Adam')
    parser.add_argument('--beta2', type=float, default=0.99,
                        help='Decay rate for 2nd moment of Adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay for optimizer')
    parser.add_argument('--num_outs_per_domain', type=int, default=10,
                        help='Number of generated images per domain during sampling')

    # misc
    parser.add_argument('--mode', type=str, required=True,
                        choices=['train', 'sample', 'eval', 'align'],
                        help='This argument is used in solver')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='Number of workers used in DataLoader')
    parser.add_argument('--seed', type=int, default=777,
                        help='Seed for random number generator')

    # directory for training
    parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
                        help='Directory containing training images')
    parser.add_argument('--val_img_dir', type=str, default='data/celeba_hq/val',
                        help='Directory containing validation images')
    parser.add_argument('--sample_dir', type=str, default='expr/samples',
                        help='Directory for saving generated images')
    parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints',
                        help='Directory for saving network checkpoints')

    # directory for calculating metrics
    parser.add_argument('--eval_dir', type=str, default='expr/eval',
                        help='Directory for saving metrics, i.e., FID and LPIPS')

    # directory for testing
    parser.add_argument('--result_dir', type=str, default='expr/results',
                        help='Directory for saving generated images and videos')
    parser.add_argument('--src_dir', type=str, default='assets/representative/celeba_hq/src',
                        help='Directory containing input source images')
    parser.add_argument('--ref_dir', type=str, default='assets/representative/celeba_hq/ref',
                        help='Directory containing input reference images')
    parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female',
                        help='input directory when aligning faces')
    parser.add_argument('--out_dir', type=str, default='assets/representative/celeba_hq/src/female',
                        help='output directory when aligning faces')

    # face alignment
    parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt')
    parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz')

    # step size
    parser.add_argument('--print_every', type=int, default=10)
    parser.add_argument('--sample_every', type=int, default=5000)
    parser.add_argument('--save_every', type=int, default=10000)
    parser.add_argument('--eval_every', type=int, default=50000)

    args = parser.parse_args()
    main(args)
python test.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 --checkpoint_dir expr/checkpoints/celeba_hq --result_dir expr/results/celeba_hq --src_dir assets/representative/celeba_hq/src --ref_dir assets/representative/celeba_hq/ref

在这里插入图片描述

在这里插入图片描述

参考文献

[1] StarGAN v2 源代码地址:https://github.com/clovaai/stargan-v2.git
[2] StarGAN v2 论文地址:https://arxiv.org/abs/1912.01865

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
StarGAN v2是一种先进的图像生成模型,旨在将一组输入图像转换为多个可能的目标域图像。该模型具有许多有用的功能和创新。 首先,StarGAN v2建立在StarGAN的基础上,通过引入一个新的概念,即多个生成器和判别器,大大提高了模型的生成能力。每个生成器与一个特定目标域相关联,并且可以从输入图像生成与目标域相关的图像。多个判别器用于提供有关输入图像和生成图像之间的真实性的反馈,从而帮助生成更高质量的图像。 其次,StarGAN v2引入了一个新的概念称为样式代码。样式代码是一个向量,代表了输入图像和目标域之间的潜在特征。通过改变样式代码的值,可以在目标域中生成具有不同外观和特征的图像。这使得模型更加灵活和可控,用户可以根据需要对图像进行个性化的转换。 另外,StarGAN v2还引入了两个重要的改进,称为判别器样式适应和循环一致性损失。判别器样式适应用于提高判别器的性能,使其能够更好地区分生成图像和目标域中真实图像之间的区别。循环一致性损失则用于确保生成器能够在两个目标域之间进行无缝转换,而不会丢失细节或信息。 最后,StarGAN v2通过使用特征对齐损失进一步提高了生成图像的质量。特征对齐损失用于确保在生成图像和真实图像之间的特征分布保持一致,从而使得生成图像更加逼真和真实。 总之,StarGAN v2是一个令人印象深刻的图像生成模型,通过引入多个生成器和判别器、样式代码、判别器样式适应、循环一致性损失和特征对齐损失,实现了高质量和高度可控的图像转换。它在许多应用领域,如人脸生成和图像风格迁移中具有巨大的潜力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FriendshipT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值