StarGAN v2:多领域的不同图像合成
前言
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv10训练自己的数据集(交通标志检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目
相关介绍
- [1] StarGAN v2 源代码地址:https://github.com/clovaai/stargan-v2.git
- [2] StarGAN v2 论文地址:https://arxiv.org/abs/1912.01865
StarGAN v2 是一种高级的生成对抗网络 (GAN) 架构,专门设计用于多域图像合成与转换。它是在 StarGAN 的基础上发展起来的,旨在解决多样性和可扩展性的问题。以下是关于 StarGAN v2 的详细介绍及其优缺点:
StarGAN v2 的工作原理
核心思想
- Domain 和 Style 分离:StarGAN v2 将域(domain)和风格(style)的概念分离,允许用户独立控制图像的内容和风格。
- Mapping Network:引入了一个映射网络(mapping network),负责将随机编码映射到不同的伪风格码。
- Style Encoder:使用风格编码器(style encoder)来从真实的图像中获取风格码,这样可以确保生成的图像既具有多样性又保持了原有的风格细节。
主要组件
- Generator:负责生成图像,能够处理多个域的转换。
- Discriminator:用于判断生成图像的真实度,帮助训练过程更加稳定。
- Mapping Network:将随机噪声映射到风格空间。
- Style Encoder:从真实图像中提取风格信息。
优点
- 多样性:能够生成具有高度多样性的图像,避免了生成结果过于单一的问题。
- 可扩展性:能够在一个模型中处理多个不同的域,无需为每个域单独训练模型。
- 统一的框架:提供了统一的框架来处理不同的任务,如人脸属性编辑、动物种类转换等。
- 灵活性:用户可以独立控制图像的内容和风格,增加了使用的灵活性。
- 高质量输出:在多个基准测试上展示了高质量的图像生成能力。
缺点
- 复杂性:引入了多个组件(如映射网络和风格编码器),使得模型结构更加复杂,可能增加了训练难度。
- 训练资源需求:复杂的架构可能需要更多的计算资源和更长的训练时间。
- 泛化能力限制:虽然在特定数据集上表现出色,但可能在其他数据集或领域上的泛化能力有待验证。
- 标签信息限制:在使用多个数据集进行训练时,每个数据集只包含部分标签信息,这可能会影响某些任务的表现。
应用实例
- 人脸属性编辑:例如改变人脸的性别、年龄等属性。
- 动物种类转换:例如将猫的图像转换为狗的图像。
- 其他图像转换任务:如季节变换、绘画风格转移等。
总结
StarGAN v2 是一个多域图像转换的前沿技术,它在多样性和可扩展性方面取得了显著的进步,使其成为图像合成领域的重要工具。不过,它的复杂性也可能带来一些挑战,尤其是在实际部署和训练过程中。
实验环境
python=3.6.7
pytorch=1.4.0
torchvision=0.5.0
ffmpeg=4.0.2
opencv-python==4.1.2.30
ffmpeg-python==0.2.0
scikit-image==0.16.2
pillow==7.0.0
scipy==1.2.1
tqdm==4.43.0
munch==2.5.0
项目地址
- StarGAN v2 源代码地址:https://github.com/clovaai/stargan-v2.git
Linux
git clone https://github.com/clovaai/stargan-v2.git
cd stargan-v2
Windows
请到
https://github.com/clovaai/stargan-v2.git
网站下载源代码zip压缩包。
cd stargan-v2
项目结构
stargan-v2
├─assets
│ └─representative
│ ├─afhq
│ │ ├─ref
│ │ │ ├─cat
│ │ │ ├─dog
│ │ │ └─wild
│ │ └─src
│ │ ├─cat
│ │ ├─dog
│ │ └─wild
│ ├─celeba_hq
│ │ ├─ref
│ │ │ ├─female
│ │ │ └─male
│ │ └─src
│ │ ├─female
│ │ └─male
│ └─custom
│ ├─female
│ └─male
├─core
│ └─__pycache__
├─data
│ └─celeba_hq
│ ├─train
│ │ ├─female
│ │ └─male
│ └─val
│ ├─female
│ └─male
├─expr
│ ├─checkpoints
│ │ └─celeba_hq
│ └─results
│ └─celeba_hq
└─metrics
└─__pycache__
具体用法
数据集和预训练网络
- 源码提供了一个脚本来下载StarGAN v2和相应的预训练网络中使用的数据集。数据集和网络检查点将被下载并分别存储在data和expr/检查点目录中。
- CelebA-HQ。要下载CelebA-HQ数据集和预训练网络,可以运行以下命令:
bash download.sh celeba-hq-dataset
bash download.sh pretrained-network-celeba-hq
bash download.sh wing
- AFHQ。要下载AFHQ数据集和预训练网络,可以运行以下命令:
bash download.sh afhq-dataset
bash download.sh pretrained-network-afhq
进行评估(以celeba-hq为例)
使用 Fréchet Inception Distance (FID) and Learned Perceptual Image Patch Similarity (LPIPS),进行评估StarGAN v2。
# celeba-hq
python main.py --mode eval --num_domains 2 --w_hpf 1 --resume_iter 100000 --train_img_dir data/celeba_hq/train --val_img_dir data/celeba_hq/val --checkpoint_dir expr/checkpoints/celeba_hq --eval_dir expr/eval/celeba_hq
# afhq
python main.py --mode eval --num_domains 3 --w_hpf 0 --resume_iter 100000 --train_img_dir data/afhq/train --val_img_dir data/afhq/val --checkpoint_dir expr/checkpoints/afhq --eval_dir expr/eval/afhq
- 注意,评估指标是使用随机潜在向量或参考图像计算的,这两者都是由种子数选择的。
进行训练(以celeba-hq为例)
要从头开始训练StarGAN v2,运行以下命令。生成的映像和网络检查点将分别存储在expr/samples和expr/检查点目录中。在单个Tesla V100 GPU上训练大约需要三天。请参阅这里的训练参数和它们的描述。
# celeba-hq
python main.py --mode train --num_domains 2 --w_hpf 1 --lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 --train_img_dir data/celeba_hq/train --val_img_dir data/celeba_hq/val
# afhq
python main.py --mode train --num_domains 3 --w_hpf 0 --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 --train_img_dir data/afhq/train --val_img_dir data/afhq/val
进行测试(以celeba-hq为例)
新建一个test.py文件,内容如下。
import os
import argparse
from munch import Munch
from torch.backends import cudnn
import torch
from core.data_loader import get_train_loader
from core.data_loader import get_test_loader
from core.solver import Solver
def str2bool(v):
return v.lower() in ('true')
def subdirs(dname):
return [d for d in os.listdir(dname)
if os.path.isdir(os.path.join(dname, d))]
def main(args):
print(args)
cudnn.benchmark = True
torch.manual_seed(args.seed)
solver = Solver(args)
if args.mode == 'train':
assert len(subdirs(args.train_img_dir)) == args.num_domains
assert len(subdirs(args.val_img_dir)) == args.num_domains
loaders = Munch(src=get_train_loader(root=args.train_img_dir,
which='source',
img_size=args.img_size,
batch_size=args.batch_size,
prob=args.randcrop_prob,
num_workers=args.num_workers),
ref=get_train_loader(root=args.train_img_dir,
which='reference',
img_size=args.img_size,
batch_size=args.batch_size,
prob=args.randcrop_prob,
num_workers=args.num_workers),
val=get_test_loader(root=args.val_img_dir,
img_size=args.img_size,
batch_size=args.val_batch_size,
shuffle=True,
num_workers=args.num_workers))
solver.train(loaders)
elif args.mode == 'sample':
assert len(subdirs(args.src_dir)) == args.num_domains
assert len(subdirs(args.ref_dir)) == args.num_domains
loaders = Munch(src=get_test_loader(root=args.src_dir,
img_size=args.img_size,
batch_size=args.val_batch_size,
shuffle=False,
num_workers=args.num_workers),
ref=get_test_loader(root=args.ref_dir,
img_size=args.img_size,
batch_size=args.val_batch_size,
shuffle=False,
num_workers=args.num_workers))
solver.sample(loaders)
elif args.mode == 'eval':
solver.evaluate()
elif args.mode == 'align':
from core.wing import align_faces
align_faces(args, args.inp_dir, args.out_dir)
else:
raise NotImplementedError
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# model arguments
parser.add_argument('--img_size', type=int, default=256,
help='Image resolution')
parser.add_argument('--num_domains', type=int, default=2,
help='Number of domains')
parser.add_argument('--latent_dim', type=int, default=16,
help='Latent vector dimension')
parser.add_argument('--hidden_dim', type=int, default=512,
help='Hidden dimension of mapping network')
parser.add_argument('--style_dim', type=int, default=64,
help='Style code dimension')
# weight for objective functions
parser.add_argument('--lambda_reg', type=float, default=1,
help='Weight for R1 regularization')
parser.add_argument('--lambda_cyc', type=float, default=1,
help='Weight for cyclic consistency loss')
parser.add_argument('--lambda_sty', type=float, default=1,
help='Weight for style reconstruction loss')
parser.add_argument('--lambda_ds', type=float, default=1,
help='Weight for diversity sensitive loss')
parser.add_argument('--ds_iter', type=int, default=100000,
help='Number of iterations to optimize diversity sensitive loss')
parser.add_argument('--w_hpf', type=float, default=1,
help='weight for high-pass filtering')
# training arguments
parser.add_argument('--randcrop_prob', type=float, default=0.5,
help='Probabilty of using random-resized cropping')
parser.add_argument('--total_iters', type=int, default=100000,
help='Number of total iterations')
parser.add_argument('--resume_iter', type=int, default=0,
help='Iterations to resume training/testing')
parser.add_argument('--batch_size', type=int, default=4,
help='Batch size for training')
parser.add_argument('--val_batch_size', type=int, default=32,
help='Batch size for validation')
parser.add_argument('--lr', type=float, default=1e-4,
help='Learning rate for D, E and G')
parser.add_argument('--f_lr', type=float, default=1e-6,
help='Learning rate for F')
parser.add_argument('--beta1', type=float, default=0.0,
help='Decay rate for 1st moment of Adam')
parser.add_argument('--beta2', type=float, default=0.99,
help='Decay rate for 2nd moment of Adam')
parser.add_argument('--weight_decay', type=float, default=1e-4,
help='Weight decay for optimizer')
parser.add_argument('--num_outs_per_domain', type=int, default=10,
help='Number of generated images per domain during sampling')
# misc
parser.add_argument('--mode', type=str, required=True,
choices=['train', 'sample', 'eval', 'align'],
help='This argument is used in solver')
parser.add_argument('--num_workers', type=int, default=0,
help='Number of workers used in DataLoader')
parser.add_argument('--seed', type=int, default=777,
help='Seed for random number generator')
# directory for training
parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
help='Directory containing training images')
parser.add_argument('--val_img_dir', type=str, default='data/celeba_hq/val',
help='Directory containing validation images')
parser.add_argument('--sample_dir', type=str, default='expr/samples',
help='Directory for saving generated images')
parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints',
help='Directory for saving network checkpoints')
# directory for calculating metrics
parser.add_argument('--eval_dir', type=str, default='expr/eval',
help='Directory for saving metrics, i.e., FID and LPIPS')
# directory for testing
parser.add_argument('--result_dir', type=str, default='expr/results',
help='Directory for saving generated images and videos')
parser.add_argument('--src_dir', type=str, default='assets/representative/celeba_hq/src',
help='Directory containing input source images')
parser.add_argument('--ref_dir', type=str, default='assets/representative/celeba_hq/ref',
help='Directory containing input reference images')
parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female',
help='input directory when aligning faces')
parser.add_argument('--out_dir', type=str, default='assets/representative/celeba_hq/src/female',
help='output directory when aligning faces')
# face alignment
parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt')
parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz')
# step size
parser.add_argument('--print_every', type=int, default=10)
parser.add_argument('--sample_every', type=int, default=5000)
parser.add_argument('--save_every', type=int, default=10000)
parser.add_argument('--eval_every', type=int, default=50000)
args = parser.parse_args()
main(args)
python test.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 --checkpoint_dir expr/checkpoints/celeba_hq --result_dir expr/results/celeba_hq --src_dir assets/representative/celeba_hq/src --ref_dir assets/representative/celeba_hq/ref
参考文献
[1] StarGAN v2 源代码地址:https://github.com/clovaai/stargan-v2.git
[2] StarGAN v2 论文地址:https://arxiv.org/abs/1912.01865
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv10训练自己的数据集(交通标志检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目