文章目录
原文翻译
MobileStyleGAN:一种用于高保真图像合成的轻量级卷积神经网络
github:https://github.com/bes-dev/MobileStyleGAN.pytorch
paper:https://arxiv.org/pdf/2104.04767
Abstract
近年来,生成对抗网络(GANs)在生成图像建模中变得非常流行。虽然基于风格的GAN架构在高保真图像合成方面产生了最先进的结果,但在计算上,它们非常复杂。在我们的工作中,我们专注于基于风格的生成模型的性能优化。我们分析了StyleGAN2中计算难度最大的部分,并提出了对生成器网络的修改,以使在边缘设备中部署基于样式的生成网络成为可能。我们引入了MobileStyleGAN架构,它的参数比StyleGAN2少3.5倍,计算复杂度比StyleGAN2少9.5倍,同时提供了相当的质量。
1. Introduction
近年来,利用生成对抗网络(GANs)[9]显著提高了高保真图像合成。虽然DCGAN[27]等早期工作可以生成分辨率高达64x64像素的图像,但BigGAN[3]和StyleGAN[20,21,19]等现代网络允许生成高达512x512甚至1024x1024像素的逼真图像。尽管生成模型的质量有了显着提高,但图像生成仍然需要许多计算资源。高计算复杂度使得将最先进的生成模型部署到边缘设备变得困难。
例如,StyleGAN2[21]网络允许FFHQ数据集的真实人脸图像大小为1024x1024像素,FID=2.84。然而,它包含28.27M个参数,计算复杂度为143.15GMAC。
我们提出了一种新的轻量级架构,MobileStyleGAN,一种用于高质量图像生成的高分辨率生成模型。以原始StyleGAN2架构为基准,我们重新访问该网络中计算困难的部分,以创建我们自己的轻量级模型,提供类似的质量(图1)。整个网络包含8.01M参数,计算复杂度为15.09 GMAC,并为FFHQ数据集提供FID=7.75的质量。
我们的主要贡献是:
•我们引入了端到端基于小波的卷积神经网络,用于高保真图像合成。
•我们引入深度可分离调制卷积作为调制卷积的轻量级版本,以降低计算复杂性。
•我们引入了适用于图形优化(如操作融合)的解调机制的重新审视版本。我们提出了一个基于知识蒸馏的管道来训练我们的网络。
2. Related Work
2.1. StyleGAN
StyleGAN[20]是一个用于高分辨率图像生成的现代生成模型。StyleGAN网络的主要特点是:
•采用渐进式增长,逐步提高分辨率。
•它从固定值张量生成图像,而不是像传统gan那样从随机生成的潜在变量生成图像。
•随机生成的潜在变量经过8层神经网络的非线性变换后,通过AdaIN[16]在每个分辨率上作为风格向量。
StyleGAN2[21]在StyleGAN的基础上进行了改进:
•通过使用估计统计数据进行归一化而不是使用AdaIN等实际统计数据进行归一化来消除液滴模式。
•通过使用跳跃式连接的分层发生器而不是渐进生长来减少眼睛和牙齿停滞。
•通过减少PPL和平滑潜在空间来提高图像质量。
StyleGAN2-ADA[19]通过使用自适应鉴别器增强,使StyleGAN适用于数据有限的任务。
2.2. Model acceleration
深度学习的一个重要研究方向是通过手动和自动设计轻量级架构来加速卷积神经网络(cnn)。
一些工作集中在设计有效的神经分类网络,可以用作其他任务的骨干。Howard等人[15,14,28]为移动和嵌入式视觉应用提出了称为mobilenet的高效模型。这些观点在许多文章中得到了改进[29,17,8,32,26]。
其他作品集中于高效生成模型的设计模型。Li等人[23]提出了一种基于蒸馏和神经结构搜索的条件gan自动优化通用框架[4]。Chang和Lu[5]提出了一种用于BigGAN压缩的蒸馏管道。Liu等[24]提出了一种基于注意机制的轻量级GAN。
2.3. Knowledge distillation
Hinton等人[13]提出了利用大型教师网络训练小型学生网络的知识蒸馏方法。蒸馏的主要思想是训练学生网络来模仿教师网络的行为。Aguinaldo等[2]采用知识蒸馏加速无条件gan。其他一些与gan相关的工作[30,6,23,5]也将知识蒸馏作为其管道的一部分。
2.4. 小波变换
在深度学习中使用基于小波的方法并不新鲜。小波[10]已经应用于许多计算机视觉任务,如纹理分类[7]、图像恢复[25]和超分辨率[31]。
Han等人[11]提出了基于两个子网络的not - big - gan架构:低分辨率的生成网络和上采样的超分辨率网络。作者表明,基于小波的子网络优于基于像素的方法。
与之前的工作相比,我们提出了一种端到端基于小波的CNN生成网络架构。我们表明,将基于小波的方法集成到gan中可以设计更轻量级的网络,并提供更平滑的潜在空间。
3. MobileStyleGAN Architecture
我们提出的架构 MobileStyleGAN 基于与基于样式的生成模型相关的先前工作。MobileStyleGAN包括一个映射网络和合成网络,如StyleGAN2所示。我们采用StyleGAN2的映射网络,专注于计算效率高的合成网络设计。
我们现在描述我们提出的体系结构和基本StyleGAN2之间的主要区别。然后,我们描述了MobileStyleGAN网络的基于蒸馏的训练过程。
3.1. 重新审视图像表示
虽然StyleGAN2适用于基于像素的图像表示,旨在直接预测输出图像的像素值,但在我们的工作中,我们使用了基于频率的图像表示。这样,MobileStyleGAN试图预测输出图像的离散小波变换(DWT)。
当应用于 2d 图像时,DWT 将通道转换为四个大小相等的通道,具有更低的空间分辨率和不同的频带。逆离散小波变换(IDWT)然后从小波域重构基于像素的表示(图2)。
图2。(左)原始RGB图像。(右)图像的DWT分解。A、B、C、D 是位于 HR 图像的左上角网格 2x2 中的四个像素。
这种类型的图像表示有几个优点,例如:
• 由于基于小波的图像表示比基于像素的方法包含更多的结构信息,因此我们可以使用低分辨率特征图生成高分辨率图像而不会损失准确性:
• 在我们的工作中,我们使用Haar小波[10]作为DWT和IDWT的滤波器组。在没有乘法运算的情况下,可以有效地实现使用Haar小波的IDWT(图2):
• 图像的高频细节的生成是一个复杂的问题。虽然StyleGAN的潜在空间在低频下平滑,但在高频上粗糙。与基于像素的方法相比,使用基于频率的图像表示,我们可以直接将正则化添加到信号的高频分量上,这使得潜在空间在低频和高频上都平滑。
3.2.重新审视渐进式增长
StyleGAN2 使用 skip-generator 通过从同一图像的多个分辨率显式求和 RGB 值来形成输出图像。我们发现,当我们预测小波域中的图像时,基于跳跃连接的预测头不会对生成的图像的质量做出重要贡献。因此,为了降低计算复杂度,我们将skipgenerator替换为网络最后一个块中的单个预测头。然而,从中间块预测目标图像对于稳定图像合成很重要。因此,我们为每个中间块添加一个辅助预测头,根据其空间分辨率预测目标图像。
StyleGAN2 和 MobileStyleGAN 预测头之间的差异如图 3 所示。
图3。(左)StyleGAN2预测头。(右)MobileStyleGAN预测头。
3.3.深度可分离卷积
受MobileNet[15]的启发,MobileStyleGAN基于深度可分离卷积,一种分解卷积的形式,将标准卷积分解为3x3深度卷积和称为点卷积的一维卷积。
如[21]所述,调制卷积由调制、卷积和归一化组成(图4的左侧面板)。深度可分离卷积还包括这些部分(图4的中间面板)。然而,我们注意到,虽然 StyleGAN2 描述了应用于权重的调制/解调,但我们将它们分别应用于输入/输出激活。这种操作顺序使得很容易描述深度可分离卷积。让我们首先考虑调制和卷积的影响。调制根据传入的样式缩放卷积的每个输入特征图:
其中 x 和 x′ 分别是原始和调制的输入激活,s 是对应于输入特征图的尺度。然后我们依次应用 3x3 深度卷积和 1x1 逐点卷积,它们之间没有任何非线性:
现在我们应用解调从输出特征图的统计数据中去除 s 的影响。由于卷积算子的线性,顺序应用的深度卷积和像素卷积的结果等于应用密集卷积的结果:
这样,解调系数可以计算为:
其中 i 和 j 和 k 分别枚举卷积的输入/输出特征图和空间足迹。为了解调输出特征图,我们应用:
图4。(左)调制卷积。(中)深度可分离卷积。(右)具有可训练解调的深度可分离调制卷积。
3.4.解调融合
批量归一化融合是一种流行的技术,可以在推理时降低卷积网络的计算复杂度。这种技术依赖于我们可以将两个线性操作合并为一个事实。解调机制类似于批处理归一化,但在推理时不是线性操作。在等式 6 之后,其中权重是固定的,解调是样式的函数。为了使解调常数,我们将样式系数替换为可训练参数(图 4 的右侧面板):
因此,解调在推理时变为常数,可以合并到像素卷积权重中。我们发现这种技术不会对生成的图像的质量产生不利影响。
3.5. 重新审视 Upscale
虽然StyleGAN2构建块使用ConvTranspose(图5的左侧面板)来升级输入特征映射,但在第3.1节中,我们使用IDWT作为MobileStyleGAN构建块中的高级函数(图5的右侧面板)。由于IDWT不包括可训练参数,我们在IDWT层之后添加额外的深度可分离卷积。
StyleGAN2和MobileStyleGAN的完整构建块结构如图6所示。
图5。(左)StyleGAN2 upscale块。(右)MobileStyleGAN upscale块。箭头上的黄色矩形显示了相关特征图中的通道数。
图6。(左)StyleGAN2构建块。(右)MobileStyleGAN构建块。
4. Training framework
与之前的工作[5,23]一样,我们的训练框架基于知识蒸馏技术[13]。给定StyleGAN2[21]作为教师网络,我们训练MobileStyleGAN模仿其功能。整个框架如图7所示。在本节中,我们将讨论训练框架的主要部分。
4.1. Data preparation
给定原始的 StyleGAN2 生成器,我们可以将未配对的学习转换为配对设置。为此,我们准备了三元组数据 {style, noise,
I
t
e
a
c
h
e
r
I_{teacher}
Iteacher },其中 style 是给定噪声向量 z 的映射网络的输出,噪声是教师和学生网络之间共享的噪声,
I
t
e
a
c
h
e
r
I_{teacher}
Iteacher是教师网络对给定风格的输出。
如第 3.2 节所述,MobileStyleGAN 的每个块预测其空间大小的输出图像。因此,我们不使用
I
t
e
a
c
h
e
r
I_{teacher}
Iteacher,而是使用
I
t
e
a
c
h
e
r
p
y
r
a
m
i
d
I^{pyramid }_{teacher}
Iteacherpyramid作为基本事实。
I
t
e
a
c
h
e
r
p
y
r
a
m
i
d
I^{pyramid }_{teacher}
Iteacherpyramid是从
I
t
e
a
c
h
e
r
I_{teacher}
Iteacher构建的图像金字塔。因此,我们训练数据表示为三元组数据{style, noise,
I
t
e
a
c
h
e
r
p
y
r
a
m
i
d
I^{pyramid }_{teacher}
Iteacherpyramid}。
为了防止过度拟合,我们不使用预处理数据集。相反,我们在学习过程中动态生成数据。
为了减少学习过程中的内存消耗,我们只使用StyleGAN2生成的人工样本,而不使用真实数据。
4.2.训练目标
我们现在详细说明知识蒸馏的建议目标。
像素级蒸馏损失(图8)。由于MobileStyleGAN旨在预测小波域中的目标图像,模拟StyleGAN2功能的朴素方法最小化了StyleGAN2生成的图像小波变换与MobileStyleGAN输出之间的像素级距离。此外,我们添加了一个正则化项,以最小化我们的地面实况和基于像素的域中的预测图像之间的像素级距离。我们发现这个术语允许我们相互同步训练不同的频率。如第 3.2 节所述,我们的网络在每个空间大小预测输出图像。因此,在每个尺度上应用基于像素的蒸馏损失。
形式上,让
其中
F
s
i
F^{i}_{s}
Fsi 是学生网络第 i 个块预测的小波域中的图像,
I
t
i
I^{i}_{t}
Iti 是相应空间大小的基于像素的域中的真实图像。
图8。像素级蒸馏损失。
感知损失。前面描述的像素级损失没有捕捉到输出图像和地面真实图像之间的感知差异。为了解决这个问题,我们使用感知损失作为目标。我们的感知损失基于 VGG16 特征并实现,如 [18] 中所述。我们只将感知损失应用于MobileStyleGAN生成的输出图像。
形式上,让
其中
V
G
G
16
(
.
.
.
)
l
VGG16(...)_{l}
VGG16(...)l 是 VGG16 的相应层
l
l
l ∈ [relu1_2, relu2_2, relu3_3, relu4_3] 的中间特征,
I
s
256
x
256
I^{256x256}_{s}
Is256x256 是学生网络预测的输出图像(调整为 256x256),其中
I
t
256
x
256
I^{256x256}_{t}
It256x256 是教师网络预测的输出图像(调整为 256x256)。
GAN损失。仅使用像素级和感知损失会导致模糊图像的生成。为了锐化生成的图像,我们在管道中包含一个鉴别器网络。我们对生成器采用 GAN 损失:
对于鉴别器网络:
其中
f
(
t
)
=
−
l
o
g
(
1
+
e
x
p
(
−
t
)
)
f (t) = − log(1 + exp(−t))
f(t)=−log(1+exp(−t)) 是 softplus 函数,
D
T
(
.
.
.
)
DT (...)
DT(...) 是具有可微增强的鉴别器网络,
G
s
t
u
d
e
n
t
(
.
.
.
)
/
G
t
e
a
c
h
e
r
(
.
.
.
)
G_{student}(...)/G_{teacher} (...)
Gstudent(...)/Gteacher(...) 是学生/教师网络,(风格、噪声) 是输入配对数据。鉴别器具有与 [20] 中相同的拓扑。
我们发现当我们压缩生成器网络时,很难平衡生成器和鉴别器网络之间的容量。因此,R1 正则化是蒸馏管道 GAN Loss 的重要组成部分,它允许在生成器和鉴别器之间重新校准容量不平衡。
完整的目标。最后,我们将完整的目标定义为:
其中超参数 λ1、λ2、λ3 控制每个项的重要性。
5. Experiments
为了训练我们的MobileStyleGAN网络,我们使用在FFHQ[20]数据集上训练的StyleGAN2教师网络。
5.1. Training
MobileStyleGAN 使用 Adam 算法 [22] 进行训练,β1 = 0.9,β2 = 0.999。生成器和判别器学习率都设置为常数 5 e − 4 5e^{-4} 5e−4。我们使用仿射变换和裁剪作为判别器输入处的可微增强。目标函数集的超参数固定为λ1 = 1.0, λ2 = 1.0, λ3 = 0.1。我们在每个优化步骤更新生成器和鉴别器。在 4 x NVIDIA 2080Ti GPU 上训练大约需要 3 天,批量大小为=8。
5.2. Results
表 1 显示了我们在 FFHQ 数据上评估 MobileStyleGAN 的结果。我们比较了MobileStyleGAN和教师网络(StyleGAN2)的参数数量、计算成本和Frechet起始距离(FID)[12]。
此外,我们评估了 StyleGAN2 和 MobileStyleGAN 在 CPU 上的推理时间。我们使用配备 Intel® Core™ i5-8279U CPU 的笔记本电脑进行实验。表 2 显示了推理时间评估的结果。正如我们之前注意到的,我们在MobileStyleGAN中使用的一些架构技巧提供了可用的图优化,如恒定折叠等。当使用自动应用图优化的OpenVINO[1]等特殊推理引擎时,我们获得了更好的性能。
表1。主要结果。FFHQ 数据集的 StyleGAN 和 MobileStyleGAN 之间的比较,1024x1024。
表 2. CPU (Intel® Core™ i5-8279U) 上的推理时间。
6. Conclusion
在这项工作中,我们解决了高保真图像合成的问题,适用于边缘设备的部署。提出了一种新的基于知识蒸馏的风格的轻量级生成网络和训练管道。我们已经公开了我们的训练代码(https://github.com/bes-dev/MobileStyleGAN.pytorch),以及一个简单的 python 库,用于在 CPU (https://github.com/bes-dev/random_face) 上快速随机人脸合成。随附的视频可以在 YouTube (https://www.youtube.com/playlist?list=PLstKhmdpWBtwsvq_27ALmPbf_mBLmk0uI) 上找到。
一些技术可以进一步提高性能和准确性,例如量化和修剪。我们将它们留作未来的研究。