模型简介
StarGAN 是一种生成对抗网络(GAN),它特别适用于图像到图像的转换任务。与传统的图像转换方法不同,StarGAN 能够一次性在多个域(domains)之间进行转换,而不需要为每个转换任务训练单独的模型。
多域图像转换:与其他方法不同,StarGAN 可以在多个域之间进行转换,而不仅仅是两个域。这是通过为每个图像添加一个域标签(domain label)来实现的。
共享生成器:只使用一个生成器来处理不同的转换任务。生成器根据输入的域标签,生成相应目标域的图像。
共享判别器:一个判别器同时用于区分真假图像和识别图像的域标签。
模型拆解
网络结构
损失函数
评价指标
模型调试
训练指令
python main.py --mode train --num_domains 2 --w_hpf 1 \
--lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 \
--train_img_dir data/celeba_hq/train \
--val_img_dir data/celeba_hq/val
python main.py --mode train --num_domains 2 --w_hpf 1 \
--lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 \
--train_img_dir D:\BaiduNetdiskDownload\FFHQ_Train \
--val_img_dir D:\BaiduNetdiskDownload\FFHQ_Train
常见问题
Q1:损失函数的数值为nan?
问题分析
在模型训练的第一轮(也就是第一个epoch或iteration)中损失函数的数值为nan,这通常会严重影响后续的训练。一旦损失函数变成nan,模型的权重更新通常也会变得无效,因为nan会通过梯度传播影响整个网络。这样的话,后续的训练循环也将继续产生 nan
的损失值,除非问题得到解决。
解决方案
检查初始化:确保网络的权重初始化适合所选的激活函数。例如,对于ReLU激活函数,He初始化通常是一个好选择。
调整学习率:尝试使用较小的学习率开始训练,查看是否可以避免出现 nan
。
检查数据:确保输入数据没有问题,例如没有包含无效值(如无穷大或非数值)。
使用梯度裁剪:梯度裁剪可以帮助防止梯度爆炸,这是导致 nan
的一个常见原因。
调试输出:在模型的关键点添加调试输出,以查看何时何地首次生成了 nan
值。这可以帮助识别是数据问题、实现错误还是数值稳定性问题。
相关调研
文献检索(谷歌学术/知网)
技术概览(google/github/huggingface/csdn/知乎/bing/百度/b站)
[1] github项目:https://github.com/clovaai/stargan-v2
[2] error: https://eternallybored.org/misc/wget/
[3] net-weight1: https://www.dropbox.com/s/96fmei6c93o8b8t/100000_nets_ema.ckpt?e=2&dl=0
[4] net-weight2: https://www.dropbox.com/s/tjxpypwpt38926e/wing.ckpt?e=1&dl=0
[5] error: https://blog.csdn.net/m0_52122736/article/details/115802983
[6] error: https://www.cnblogs.com/yanzhao-x/p/15984054.html
[7] github项目:https://github.com/a312863063/generators-with-stylegan2