数据集
在不同分辨率下在不同数据集上训练的预训练 StyleGAN 模型的集合。
Dataset | images | information |
---|---|---|
LSUN Bedrooms | ||
LSUN Cars | ||
LSUN Cats | ||
CelebA HQ Faces | ||
FFHQ Faces | ||
Pokemon | ||
Anime Faces | ||
Anime Portraits | ||
WikiArt Faces | ||
Abstract Photos | ||
Vases | ||
Fireworks | ||
Ukiyo-e Faces | ||
Butterflies |
使用预训练网络
pretrained_example.py 中给出了使用预训练 StyleGAN 生成器的最小示例。 执行时,脚本会从 Google Drive 下载一个预先训练好的 StyleGAN 生成器,并使用它来生成图像:
> python pretrained_example.py
Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done
Gs Params OutputShape WeightShape
--- --- --- ---
latents_in - (?, 512) -
...
images_out - (?, 3, 1024, 1024) -
--- --- --- ---
Total 26219627
> ls results
example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
generate_figures.py 中给出了一个更高级的示例。 该脚本复制了我们论文中的数字,以说明样式混合、噪声输入和截断:
> python generate_figures.py
results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu
results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6
results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG
results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_
results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v
results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr
results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke
results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
预训练的网络作为标准 pickle 文件存储在 Google Drive 上:
# Load pre-trained network.
url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
_G, _D, Gs = pickle.load(f)
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
上面的代码下载文件并解压它以产生 dnnlib.tflib.Network 的 3 个实例。 要生成图像,您通常需要使用 Gs——另外两个网络是为了完整性而提供的。 为了使 pickle.load() 工作,您需要在 PYTHONPATH 中包含 dnnlib 源目录,并将 tf.Session 设置为默认值。 会话可以通过调用 dnnlib.tflib.init_tf() 来初始化。
使用预训练生成器的三种方式:
-
使用 Gs.run() 进行立即模式操作,其中输入和输出是 numpy 数组:
第一个参数是一批形状为 [num, 512] 的潜在向量。 第二个参数是为类标签保留的(StyleGAN 不使用)。 其余的关键字参数是可选的,可用于进一步修改操作(见下文)。 输出是一批图像,其格式由 output_transform 参数指定。# Pick latent vector. rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
-
使用 Gs.get_output_for() 将生成器合并为更大的 TensorFlow 表达式的一部分:
上面的代码来自metrics/frechet_inception_distance.py。 它生成一批随机图像并将它们直接提供给 Inception-v3 网络,而无需将数据转换为中间的 numpy 数组。latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) images = tflib.convert_images_to_uint8(images) result_expr.append(inception_clone.get_output_for(images))
-
查找 Gs.components.mapping 和 Gs.components.synthesis 以访问生成器的各个子网络。 与 Gs 类似,子网络表示为 dnnlib.tflib.Network 的独立实例:
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
上面的代码来自 generate_figures.py。 它首先使用映射网络将一批潜在向量转换为中间 W 空间,然后使用合成网络将这些向量转换为一批图像。 dlatents 数组为合成网络的每一层存储相同 w 向量的单独副本,以促进风格混合。
生成器的确切细节在 training/networks_stylegan.py 中定义(参见 G_style、G_mapping 和 G_synthesis)。 可以指定以下关键字参数来修改调用 run() 和 get_output_for() 时的行为:
- truncation_psi 和 truncation_cutoff 控制使用 Gs (ψ=0.7, cutoff=8) 时默认执行的截断技巧。 可以通过设置 truncation_psi=1 或 is_validation=True 来禁用它,并且可以通过设置以变化为代价进一步提高图像质量,例如 截断_psi=0.5。 请注意,直接使用子网时,截断始终处于禁用状态。 可以使用 Gs.get_var(‘dlatent_avg’) 查找手动执行截断技巧所需的平均 w。
- randomize_noise 确定是否对每个生成的图像使用重新随机化噪声输入(True,默认)或是否对整个 minibatch 使用特定的噪声值(False)。 可以通过使用 [var for name, var in Gs.components.synthesis.vars.items() if name.startswith(‘noise’)] 找到的 tf.Variable 实例访问特定值。
- 直接使用映射网络时,您可以指定 dlatent_broadcast=None 以禁用合成网络层上的 dlatents 自动复制。
- 运行时性能可以通过 structure=‘fixed’ 和 dtype=‘float16’ 进行微调。 前者禁用了对完全训练生成器不需要的渐进式增长的支持,后者使用半精度浮点算法执行所有计算。
准备训练数据集
训练和评估脚本对存储为多分辨率 TFRecord 的数据集进行操作。每个数据集都由一个目录表示,该目录包含多种分辨率的相同图像数据,以实现高效的流式传输。每个分辨率都有一个单独的 *.tfrecords 文件,如果数据集包含标签,它们也会存储在单独的文件中。 默认情况下,脚本希望在 datasets//-.tfrecords 中找到数据集。 可以通过编辑 config.py 来更改目录:
result_dir = 'results'
data_dir = 'datasets'
cache_dir = 'cache'
要获取 FFHQ 数据集 (datasets/ffhq),请参阅 Flickr-Faces-HQ 存储库。
要获取 CelebA-HQ 数据集 (datasets/celebahq),请参阅Progressive GAN 存储库 。
要获取其他数据集,包括 LSUN,请查阅其相应的项目页面。 可以使用提供的 dataset_tool.py 将数据集转换为多分辨率 TFRecords:
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256
> python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384
> python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256
> python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10
> python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
训练网络
设置数据集后,您可以按如下方式训练自己的 StyleGAN 网络:编辑 train.py 以通过- - 取消注释或编辑特定行来指定数据集和训练配置。
- 使用 python train.py 运行训练脚本。
- 结果将写入新创建的目录 results/-。
- 培训可能需要几天(或几周)才能完成,具体取决于配置。
默认情况下,train.py 被配置为使用 8 个 GPU 以 1024×1024 分辨率为 FFHQ 数据集训练最高质量的 StyleGAN(表 1 中的配置 F)。
使用 Tesla V100 GPU 的默认配置的预期训练时间:
评估质量和解开
我们论文中使用的质量和解开度量可以使用 run_metrics.py 进行评估。 默认情况下,脚本将评估预训练的 FFHQ 生成器的 Fréchet 起始距离 (fid50k),并将结果写入结果下新创建的目录中。 可以通过取消注释或编辑 run_metrics.py 中的特定行来更改确切的行为。
使用一个 Tesla V100 GPU 的预训练 FFHQ 生成器的预期评估时间和结果:
请注意,由于 TensorFlow 的不确定性,确切的结果可能会因运行而异。
参考资料
ustinpinkney/awesome-pretrained-stylegan
BVlabs/stylegan