StyleMapGAN: Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing(2021)

47 篇文章 3 订阅
41 篇文章 10 订阅

[paper] Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing
[code] StyleMapGAN
在这里插入图片描述

摘要

生成对抗网络(GAN)从随机潜在向量合成逼真的图像。尽管操纵潜矢量控制了合成输出,但是用GAN编辑真实图像的缺点是:i)耗时的优化,无法将真实图像投影到潜矢量; ii)或通过编码器嵌入不准确。我们提出StyleMapGAN:中间潜在空间具有空间维度,而空间变异调制取代了AdaIN。与现有的基于优化的方法相比,它可通过编码器进行嵌入,同时保持GAN的属性。实验结果表明,在各种图像处理任务(例如本地编辑和图像插值)中,我们的方法明显优于最新模型。最后但并非最不重要的一点是,GAN上的常规编辑方法在我们的StyleMapGAN上仍然有效。

介绍

生成对抗网络(GANs)[16]在最近几年中发生了巨大的变化,可以使用直接从数据中学习的模型进行高保真图像合成[6,25,26]。最近的研究表明,GAN自然会学习在潜在空间内编码丰富的语义,因此更改潜在代码会导致操纵输出图像的相应属性[22、47、17、15、48、3、57、5]。然而,将这些操作应用于真实图像仍然是一个挑战,因为GAN缺少从图像返回到其对应的潜在代码的逆映射。

操纵真实图像的一种有前途的方法是图像到图像的转换[21、64、9、27、29],在该模型中,模型可以学习在给定用户输入的情况下直接合成输出图像。但是,这些方法需要预先定义的任务和严格的监督(例如,输入输出对,类标签)来进行训练,并且限制了用户在推理时的可控制性。另一种方法是通过直接优化单个图像的潜码来利用预训练的GAN模型[1,2,63,37,41]。但是,即使在高端GPU上,每个目标图像也需要几分钟的计算,并且不能保证将优化的代码放置在GAN的原始潜在空间中。

一种更实用的方法是训练一个额外的编码器,该编码器学习将图像投影到其相应的潜在代码中[34、62、44、36、45]。尽管这种方法能够以单前馈方式进行实时投影,但是它遭受了投影图像的低保真度的困扰(即,丢失了目标图像的细节)。我们将此限制归因于潜在空间中没有空间尺寸。如果没有空间维度,编码器会以纠缠的方式将图像的局部语义压缩为矢量,从而难以重建图像(例如,基于矢量或低分辨率的瓶颈层无法产生高频细节[33,8 ])。

为了解决这些问题,我们提出了StyleMapGAN,它利用了stylemap(一种潜在空间的新颖表示形式)。我们的关键思想很简单。 而不是学习基于矢量的潜在表示,我们使用具有显式空间尺寸的张量。我们提出的表示法得益于其空间维度,使GAN可以轻松地将图像的局部语义编码到潜在空间中。此属性使编码器可以将图像有效地投影到潜在空间中,从而提供高保真度和实时投影。我们的方法还提供了一种新功能,可通过操纵样式图的匹配位置来编辑图像的特定区域。图1显示了我们的本地编辑和本地语义操作结果。请注意,所有编辑都是实时进行的。 如图2所示,您可以测试我们的Web演示以进行交互式编辑。
在这里插入图片描述
与传统的基于矢量的潜在表示(第4.3节)相比,我们的样式图确实在多个数据集上大大提高了投影质量。此外,我们在图像投影,插值和局部编辑(第4.4和第4.5节)方面展示了我们的方法相对于最新方法的优势。最后,我们证明了即使在一个图像与另一个图像之间未对齐的区域中,我们的方法也可以移植区域(第4.6节)。

相关工作

基于优化的编辑方法

迭代地更新预训练GAN的潜矢量,以将真实图像投影到潜空间[63,7,1,62,20,4]。例如,Image2StyleGAN [1]通过优化StyleGAN [25]每一层的中间表示来重建图像。In-DomainGAN [62]不仅着重于在像素空间中重建图像,而且着重于将反码置于原始潜在空间的语义域中。神经拼贴[53]和pix2latent [20]提出了一种混合优化策略,用于将图像投影到类条件GAN(例如BigGAN [6])的潜在空间中。另一方面,我们利用了编码器,该编码器的编辑速度比优化方法快了2到3个数量级。

基于学习的编辑方法

训练额外的编码器以直接推断给定目标图像的潜在代码[34、13、12、14、45]。例如,ALI [14]和BiGAN [12]引入了一个完全对抗的框架,以共同学习生成器和逆映射。为了将可变自动编码器[32]与GAN一起用于潜在投影,已经进行了一些工作[34、51、55]。ALAE [45]建立了一个编码器来预测StyleGAN的中间潜在空间。然而,由于缺乏潜在空间的空间尺寸,所有上述方法都提供了有限的重建质量。交换自动编码器[43]学习将图像编码为两个组件,即结构代码和纹理代码,并在给定任何交换组合的情况下生成逼真的图像。尽管由于这样的表示,它可以快速而精确地重建图像,但是纹理代码仍然是一个矢量,这使结构化纹理传递具有挑战性。我们的编辑方法不仅可以成功反映参考图像的颜色和纹理,还可以反映其形状。

本地编辑方法

解决了编辑特定部分[11、3、65、60、49](例如鼻子,背景)的问题,这与大多数基于GAN的修改全局外观的图像编辑方法相反[47、57、43]。例如,样式编辑[11]试图确定每个通道的每层样式矢量对特定部分的贡献。结构噪声[3]用输入张量替换了StyleGAN的学习常数,该张量是局部和全局代码的组合。然而,这些方法[11、3、5]并不针对真实图像,其性能在真实图像中显着降低。SEAN [65]通过将图像编码为每个区域的样式代码并对其进行操作来简化实际图像的编辑,但是它需要成对的图像和分割蒙版进行训练。此外,样式代码仍然是向量,因此它具有与交换自动编码器[43]相同的问题。

StyleMapGAN

我们的目标是使用编码器将图像实时准确地投影到潜在空间,并在潜在空间上本地操作图像。我们提出了StyleMapGAN,它采用样式图,具有空间维度的中间潜在空间以及基于样式图(第3.1节)的空间变体调制。请注意,样式[25]之后不仅表示纹理(精细样式),还表示形状(粗略样式)。现在,编码器可以将图像嵌入样式表中,从而比基于优化的方法更准确地重建图像,并且样式表中的部分更改会导致对图像进行本地编辑(第3.3节)。

基于样式图的生成器

图3描述了建议的基于样式图的生成器。尽管传统的映射网络会生成样式矢量来控制特征图,但我们会创建具有空间尺寸的样式图,这不仅使真实图像的投影在推理上更加有效,而且还可以进行局部编辑。映射网络在最后具有一个重塑层以生成样式图,该样式图形成了空间变化的仿射参数的输入。由于合成网络中的特征图随着靠近输出图像而变得更大,因此,我们引入了一个样式图调整器,该大小调整器由卷积和上采样组成,以将样式图的分辨率与特征图进行匹配。样式图调整大小器使用学习的卷积来调整样式图的大小并对其进行转换,以传达更详细和结构化的样式。
在这里插入图片描述
然后,仿射变换生成有关调整大小的样式图的调制参数。综合网络中第 i i i层的调制操作如下:
在这里插入图片描述
我们删除了逐像素噪声,这是StyleGAN中空间变化输入的额外来源,因为我们的样式图已经提供了空间变化的输入,而单个输入使投影和编辑更加简单。有关网络以及与自动编码器方法的关系的其他详细信息,请参见补充材料(§C)[19]。

训练过程和损失

在图4中,为了简洁起见,我们分别使用F,G,E和D分别表示映射网络,具有样式图缩放器,编码器和鉴别器的综合网络。D与StyleGAN2相同,并且E的体系结构与D相似,不同之处在于没有小批量鉴别[46]。如表1所示,所有网络都经过多次损失共同训练。对G和E进行了训练,以重建像素级和感知级的真实图像[61]。当G(F)从z合成图像时,不仅图像,而且E尝试使用均方误差(MSE)重建样式图。D尝试对从高斯分布生成的真实图像和伪图像进行分类。最后,我们针对域内属性利用域指导的损失[62]。E试图通过与D竞争来重建更逼真的图像,使投影样式图更适合于图像编辑。如果我们删除任何损失函数,则会降低生成和编辑性能。有关每种损失函数(§D)和联合学习(§B)的效果,请参阅补充材料。 还涉及进一步的培训细节(§C)。
在这里插入图片描述
在这里插入图片描述

本地编辑

如图4底部所示,本地编辑的目标是相对于遮罩将参考图像的某些部分移植到原始图像中,该遮罩指示要修改的区域。请注意,遮罩可以是任何形状,允许使用语义分割方法进行交互式编辑或基于标签的编辑。
在这里插入图片描述
与SPADE [42]或SEAN [65]相反,即使是8×8的粗糙口罩也可以产生真实的图像,从而减轻了用户提供详细口罩的负担。对于两个图像的遮罩不同,可以进一步修改此操作(第4.6节)。

实验

我们提出的方法可以将图像实时有效地投影到样式空间中,并有效地操纵真实图像的特定区域。我们首先描述实验设置(第4.1节)和评估指标(第4.2节),并说明样式图的拟议空间尺寸如何影响图像投影和生成质量(第4.3节)。然后,我们将我们的方法与有关真实图像投影(第4.4节)和本地编辑(第4.5节)的最新方法进行比较。最后,我们展示了一种更加灵活的编辑方案以及我们提出的方法的有效性(第4.6节)。请参阅有关高分辨率实验(§B)和其他结果(§E)的补充材料,例如随机生成,样式混合,语义处理和失败案例。

实验装置

  • 基准线
    我们将我们的模型与最新的生成模型进行比较,包括StyleGAN2 [26],Image2StyleGAN [1],In-DomainGAN [62],结构化噪声[3],样式编辑[11]和SEAN [65]。我们使用作者提供的正式实现从头开始训练所有基线,直到它们收敛为止。对于基于优化的方法[26、1、62、3、11],我们使用其论文中指定的所有超参数。我们还在补充材料(第E.2节)中定性地将我们的方法与ALAE [45]进行了比较。注意,由于作者尚未发布他们的代码,因此我们没有将我们的方法与Image2StyleGAN ++ [2]和Swap Autoencoder [43]进行比较。
  • 数据集
    我们在CelebA-HQ [24],AFHQ [10]和LSUN Car&Church [59]上评估我们的模型。我们采用CelebA-HQ而不是FFHQ [25],因为CelebA-HQ包含分段掩码,因此我们可以训练SEAN基线并利用这些掩码在语义级别上准确评估本地编辑。AFHQ数据集包含比人脸数据集更广泛的变化,这适合于显示我们模型的一般性。优化方法耗时极长,我们将测试和验证集限制为与In-DomainGAN [62]相同的500张图像。CelebA-HQ,AFHQ和LSUN Car&Church的训练图像数量分别为29K,15K,5.5M和126K。我们以256×256的分辨率训练了所有模型,以便在合理的时间内进行比较,但是我们还在补充材料(§B)中提供了1024×1024 FFHQ的结果。

评估指标

  • Frechet inception distance (FID)
    为了评估图像生成的性能,我们计算了从高斯分布和训练集生成的图像之间的FID [18]。我们将生成的样本数设置为等于训练样本数。 我们使用ImageNet预训练的Inception-V3 [54]进行特征提取。

  • FIDlerp
    为了评估全局操纵性能,我们计算插值样本和训练样本(FIDlerp)之间的FID。为了生成内插样本,我们首先将500张测试图像投影到潜在空间中,然后随机选择成对的潜在向量。然后,我们使用线性内插的潜在向量生成图像,该向量的内插系数在0和1之间随机选择。我们将插值样本的数量设置为等于训练样本的数量。低FIDlerp表示该模型提供了高保真度和各种插值样本。

  • MSE & LPIPS
    为了评估投影质量,我们估计目标图像和重建图像之间的像素级和感知级差异,分别是均方差(MSE)和学习的感知图像斑块相似度(LPIPS)[61]。

  • Average precision (AP)
    为了评估本地编辑图像的质量,我们遵循先前工作[43]的惯例,使用在真实图像和伪图像上训练的二进制分类器来测量平均精度[58]。我们使用Blur + JPEG(0.5)模型和完整图像进行评估。较低的AP表示已处理的图像与真实图像更加难以区分。

  • MSEsrc & MSEref
    为了混合特定的语义,我们通过合并原始图像和参考图像的目标语义蒙版来制作合并的蒙版。MSEsrc和MSEref分别从蒙版外部的原始图像和蒙版内部的参考图像测量均方误差。为了自然地组合它们,图像与目标语义掩码相似度进行配对。为了在CelebA-HQ上进行本地编辑比较,每种语义(例如背景,头发)配对了250组测试图像[35],总共产生了2500张图像。为了在AFHQ上进行本地编辑,将250组测试图像随机配对,并在水平和垂直半掩膜之间选择掩膜,从而产生250幅图像。

样式图分辨率的影响

要使用生成模型处理图像,我们首先需要将图像准确地投影到其潜在空间中。在表2中,我们改变了样式图的空间分辨率,并比较了重建和生成的性能。为了公平起见,我们在训练StyleGAN2生成器之后训练编码器模型。随着空间分辨率的提高,重建精度将显着提高。它表明我们的具有空间尺寸的样式图对于图像投影非常有效。FID在数据集中的差异可能不同,这可能是由于一代之间位置之间的上下文关系不同。请注意,我们的具有空间分辨率的方法可以准确保留小的细节,例如,眼睛不会模糊。
在这里插入图片描述

接下来,我们评估样式图分辨率在编辑场景中的效果,将一幅图像的特定部分与另一幅图像混合。图5显示8×8样式图在无缝性方面保持了最合理的图像,并保留了原始图像和参考图像的身份。我们看到,当空间分辨率高于8×8时,容易检测到编辑部分。
在这里插入图片描述
此外,我们在CelebA-HQ的不同分辨率模型中估计FIDlerp。8×8模型显示出比其他分辨率模型最佳的FIDlerp值(9.97); 4×4、16×16和32×32分别为10.72、11.05和12.10。我们假设样式图的分辨率越高,编码器投影的潜在可能性越有可能脱离潜在空间,而潜在空间来自标准的高斯分布。考虑到编辑质量和FIDlerp,我们选择8×8分辨率作为最佳模型,并将其始终用于所有后续实验。

真实影像投影

在表3中,我们将我们的方法与最先进的真实图像投影方法进行了比较。对于这两个数据集,StyleImageGAN的重建质量(MSE和LPIPS)都优于Image2StyleGAN以外的所有竞争对手。但是,Image2StyleGAN无法满足编辑要求,因为它会在潜在的插值(FIDlerp和图形)中生成伪造的伪像,并且会花费几分钟的运行时间。我们的方法还获得了最佳的FIDlerp,这隐式表明我们对样式空间的操作可以生成最逼真的图像。重要的是,我们的方法比基于优化的基准运行速度至少快100倍,这是因为通过样式映射(在单个V100 GPU中进行了测量),单次前馈传递可以提供准确的投影。SEAN还以单次前馈方式运行,但是在训练和测试时都需要使用真实的分段蒙版,这在实际使用中是一个严重的缺点。
在这里插入图片描述

本地编辑

我们从三个方面评估本地编辑性能:可检测性,对蒙版中参考图像的忠实度以及在蒙版外保留原始图像。图6和7直观地表明,我们的方法无缝地将两个图像组合在一起,而其他图像则很难。由于没有用于评估后两个方面的指标,因此我们提出了两个定量指标:MSEsrc和MSEref。表4显示,我们的方法得出的结果对于分类器来说最难检测到假货,并且原始图像和参考图像都得到了最好的体现。请注意,MSE并不是唯一的措施,但是应该结合考虑AP以实现图像的真实性。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

非对齐移植

这里,我们演示了一个更灵活的用例,未对齐移植(图像混合),表明我们的局部编辑不需要对齐原始图像和参考图像上的遮罩。我们将图像投影到样式图上,并用参考样式图的裁剪替换原始样式图的指定区域,即使它们位于不同的位置。用户可以指定要替换的内容。 图8显示了LSUN汽车和教堂的示例。
在这里插入图片描述

讨论和结论

GAN的可逆性对于实际应用无条件GAN模型编辑真实图像至关重要,但尚未得到正确的答案。为了实现此目标,我们提出了StyleMapGAN,它向潜伏空间引入了显式的空间尺寸,称为样式图。通过广泛的评估,我们证明了基于样式图的方法比以前的方法具有许多优势。它可以实时将真实图像准确地投影到潜在空间中,并通过插值和局部编辑来合成高质量的输出图像。我们相信,通过将我们的潜在表示形式应用于其他方法(例如条件GAN(例如BigGAN [6])或变体自动编码器[32])来提高保真度将是令人兴奋的未来工作。

w +空间中的本地编辑

本节说明了如何使用StyleMapGAN执行本地编辑。尽管我们已经在本文的第3.3节中介绍了局部编辑方法,但是由于粗略的蒙版分辨率(8×8),无法进行详细的编辑。与先前的方法相反,我们提出了w +空间中的局部编辑方法。无论样式图(w)的分辨率如何,我们都可以使用高分辨率的调整大小的样式图(w +)来利用详细的蒙版。

图9显示了w+空间上混合的概述。
在这里插入图片描述

在高分辨率数据集中进行实验

我们在FFHQ上以1024×1024分辨率评估模型。基线是StyleGAN2,我们还测试了Image2StyleGAN(A)。使用StyleGAN2官方预培训网络进行公平比较。StyleMapGAN对高分辨率数据集采用32×32样式图,而对于256×256图像则采用8×8样式图。StyleMapGAN-Light(E)是StyleMapGAN的简化版; 它减少了发生器的参数数量。另一种训练设置(D)是顺序学习,它首先训练生成器,然后训练编码器。在表5中,我们使用与论文相同的协议来计算MSE,LPIPS和FIDlerp。FFHQ的训练图像数量为69K,我们将测试和验证集限制为500张图像。

  • 与基线比较
    如表5所示,Image2StyleGAN可以很好地重建图像,但是插值质量很差。FIDlerp低,插值结果不可靠以及运行时间过长,表明Image2StyleGAN不适合图像编辑任务。StyleMapGAN在所有指标上均优于基准,甚至StyleMapGAN-Light也显示出惊人的结果。
    在这里插入图片描述

  • StyleMapGAN-Light
    StyleMapGAN-Light比原始版本小2.5倍。样式图调整器占据了网络的很大一部分,因此我们减少了样式图调整器中要素图的通道数量。重建图像缺少一些细节,但是StyleMapGAN-Light仍然胜过基线,并且FIDlerp甚至比原始版本更好。请参阅我们的代码以参考频道数。

  • 联合学习
    联合学习在训练StyleMapGAN时很重要。它使训练稳定,网络性能更好。训练生成器后训练编码器无法重建图像。我们推测联合学习优于顺序学习的原因如下。在联合学习中,生成器和编码器相互影响。生成器生成易于由编码器重建的图像。编码器的结构是一叠卷积层,这使投影的样式图易于具有局部对应性:样式图中的部分更改会导致对图像进行局部编辑。通过联合学习,生成器中的映射网络还学习使高斯分布的样式图具有局部对应性。

实施细节

  • 构建
    我们遵循StyleGAN2 [26]关于鉴别器体系结构和综合网络卷积层中的特征图计数。我们的映射网络是一个MLP,由八个完全连接的层和一个整形层组成。通道大小为64,最后一个为4,096。我们的编码器采用鉴别器架构,直到8×8层,并且没有小批量鉴别[46]。
  • 训练
    我们共同训练生成器、编码器和鉴别器。如§B中所述,与分别训练对抗网络和编码器相比,它更简单并导致更稳定的训练和更高的性能。对于其余部分,我们大多遵循StyleGAN2的设置,例如,鉴别器体系结构,使用γ= 10的鉴别器中的R1正则化[40],学习率为0.002,β1= 0.0和β2= 0.99的Adam [30]优化器, 生成器和编码器的指数移动平均值,泄漏的ReLU [38],所有层的学习率均等[24],随机水平翻转进行增强,并且将映射网络的学习率降低两个数量级[25]。我们的代码基于StyleGAN21的非官方PyTorch实现。使用2个Tesla V100 GPU(最小批量大小为16)在5M图像上对所有256×256的StyleMapGAN变体进行了为期两周的训练。在§B中,使用8个Tesla V100 GPU(最小批量为16个)对1024×1024模型进行了为期一周的2.5M图像训练。我们注意到,大多数情况会持续缓慢改善,直到获得1000万张图片为止。我们的代码可在网上公开获得,以提高可重复性。
  • 样式图的映射网络设计
    设计映射网络时,有几种选择。我们可以删除映射网络,以便我们的方法不会从标准的高斯分布中生成图像,而仅使用真实图像进行训练,例如自动编码器[19]。如图10所示,自动编码器无法使用投影的样式图生成逼真的图像。似乎是在RGB空间上的两个图像之间复制和粘贴。自动编码器仅使用图像作为输入,它是一个离散变量。相反,我们的方法不仅使用图像,还使用高斯分布的潜像,这是一个连续的空间。如果我们混合使用两个潜在代码来编辑图像,则与离散潜在空间相比,使用连续潜在空间进行训练可以覆盖更多潜在值。
    在这里插入图片描述
    另外,由于样式图的空间尺寸,我们可以轻松想到卷积层。但是,具有卷积层的映射网络在重建中很费力,因此编辑的结果图像与原始图像有很大的不同。我们假设存在这种限制,因为卷积层的映射仅限于局部区域。另一方面,MLP中的每个砝码和输入均已完全连接,因此可以提供更灵活的潜在空间。

损失明细

在主要论文的第3.2节中,我们简要介绍了六种损失。在本节中,我们提供了损失及其责任的详细信息。某些损失会降低重建质量(MSE,LPIPS [61]),但我们需要每一个损失才能获得最佳的编辑质量(FIDlerp)。通过全力以赴的训练,我们可以获得最佳的FIDlerp。表6显示了消融研究的定量结果。 所有损失项的系数均设置为1。
在这里插入图片描述

  • 对抗损失
    鉴别器试图将伪图像分类为伪图像,伪图像是根据高斯分布或输入图像的重构随机生成的。相反,生成器通过生成更逼真的图像来欺骗鉴别器。就平滑插值而言,来自连续空间的生成增加了生成能力。没有与映射网络有关的对抗性损失,我们将无法获得§C中提到的平滑的潜在空间流形。如果不使用对抗损失,图11还显示了不自然的插值结果和棋盘伪像。我们使用非饱和损失[16]作为对抗损失。
    在这里插入图片描述
  • 领域指导的损失
    域引导丢失是由In-DomainGAN [62]引入的。我们通过编码器和生成器对真实图像生成的伪图像进行对抗训练。鉴别器试图将生成的图像分类为伪造,而编码器和生成器试图欺骗鉴别器。丢失导致投影的潜在代码保留在GAN的原始潜在空间中,从而通过利用GAN的属性(例如,平滑插值)促进平滑的真实图像编辑。没有域引导的损耗,插值结果将变得模糊,如图11所示。
  • 潜在重建损失
    编码器的目标是找到生成目标图像的潜在代码。当我们根据高斯分布生成伪图像时,我们就知道了一对潜在代码和所生成的图像。在这种监督下,我们像其他方法一样训练编码器[56、51、45、62]。编码器尝试在原始潜在空间的语义域中投影图像,并减轻对像素级重构的强烈偏见。
  • 图像重建损失
    为了使输出图像在视觉上与输入图像相同,我们在像素级别上将它们之间的差异最小化。如果我们不使用此损失函数,则视觉重建会像ALAE [45]那样失败。
  • 感性损失
    图像重建损失通常会使编码器过拟合并输出模糊图像。几种方法[1、2、62]采用了知觉损失[23],该方法利用了VGG [50]提取的特征进行知觉水平重建。我们将LPIPS [61]用于感知损失,它具有更好的特征表示。
  • R1正则化
    R1正则化[40]使训练稳定。 我们发现懒惰的正则化[26]足够了,并每16步将其应用于鉴别器。没有此损失功能,所有指标的性能都会下降。

附加结果

在本节中,我们将显示广泛的定性结果。§E.1说明了随机生成的图像,以表明我们的方法的生成能力与基线相比不会退化。§E.2和§E.3分别提供了重构和本地编辑的扩展比较。第E.4节显示了其他未对齐的移植示例。我们的方法适用于§E.5和§E.6中所示的其他基于潜伏的编辑方法。最后,我们讨论了该方法的局限性(第E.7节)。

随机产生

GAN的主要目标是根据随机的高斯噪声生成高保真图像。我们显示每个数据集的随机生成结果:CelebA-HQ [24],AFHQ [10]和LSUN Car&Church [59]。除AFHQ之外,我们使用8×8分辨率的样式图,在这种情况下16×16分辨率可提供更好的生成质量,如主文件表2所示。为了生成高质量的图像,我们使用ψ= 0.7的截断技巧[6,31,39]。图12显示了未标定的图像和FID值。在CelebA-HQ和AFHQ中,我们使用与主要实验相同的FID协议; 生成的样本数量等于训练样本的数量。另一方面,LSUN由很多训练图像组成,因此我们使用从训练集中随机选择的50k图像; 生成的样本数也为50k。低FID显示我们的方法具有令人满意的生成能力。

图像投影和插值

尽管基于编码器的方法将图像实时投影到潜在空间中,但它们的投影质量仍未达到预期。图13显示了我们的方法与其他基于编码器的基线(ALAE [45],In-DomainGAN [62]和SEAN [65])之间的投影质量比较。

图14显示了我们的方法和Image2StyleGAN [1]的投影和插值结果。尽管Image2StyleGAN可以高保真地重建输入图像,但是它在潜插值方面很费劲,因为它在w +上的投影偏离了生成器的学习到的潜空间w。

本地编辑

图15显示了与竞争对手的本地编辑比较。由于结果不佳,我们淘汰了两个竞争对手(结构化噪声[3]和样式编辑[11]),如主要论文中的图4所示。这是因为它们不是针对编辑真实图像,而是针对伪图像。

非对齐移植

图16和17显示了LSUN Car&Church [59]中未对齐的移植结果。我们的方法可以将参考图像中任意数量和位置的区域移植到原始图像中。请注意,我们的方法会针对原始图像调整相同参考的色调和结构。

语义操纵

我们利用InterFaceGAN [47]在潜在空间中找到语义边界。我们的方法可以使用从边界得出的特定方向来更改语义属性。我们将方向应用于样式图(w空间)。图18显示了语义操纵的两个版本。全局版本是操纵属性的典型方法。由于样式表的空间尺寸,本地版本仅在我们的方法中可用。我们将语义方向应用于w空间中的指定位置。它使我们不必更改原始图像的不希望区域,而无需考虑属性相关性。例如,“ Rosy Cheeks”使嘴唇变红,而“ Goatee”在全局版本中改变鼻子的颜色,而在本地版本中改变鼻子的颜色,如图18所示。此外,我们可以更改“重妆”中的唇妆和“山羊胡”中的胡须等属性的一部分。它减轻了高度颗粒化标签的工作量。交换自动编码器[43]显示了区域编辑,该结构代码也可以在本地进行操作。但是,由于纹理代码中缺少空间尺寸,因此无法在与颜色和纹理有关的某些属性(例如“苍白的皮肤”)上应用区域编辑。

风格混合

StyleGAN [25]提出了一种样式混合方法,该方法从参考图像中复制样式的指定子集。我们在调整大小的样式图(w +)中进行样式混合。与StyleGAN不同,我们的生成器在调整后的低分辨率样式图(8×8)中具有颜色和纹理信息。另一方面,它通过其他分辨率 ( 1 6 2 − 25 6 2 ) (16^{2}-256^{2}) (1622562)生成整体结构。如果要从参考图像中获取颜色和纹理样式,请通过参考替换8×8调整大小的样式图。图19显示了示例。

使用样式混合和未对齐的移植,我们只能转移局部结构,如图20所示。我们在第一个调整大小的样式图上使用原始图像,在目标区域中的其余分辨率上使用参考图像。
在这里插入图片描述

失败案例

当原始图像和参考图像具有不同的姿势和目标语义大小时,我们的方法存在局限性。图21显示了不同姿势下的失败案例。特别是,头发插值不流畅。图22显示了不同目标语义大小下的失败案例。保险杠的大小和姿势各不相同,我们的方法无法自然移植。当样式图的分辨率提高时,此限制会变得更糟。解决该问题将是未来的有趣工作。
在这里插入图片描述
在这里插入图片描述

代码解读

generate.py代码调试

  1. 加载模型
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--mixing_type",
        choices=[
            "local_editing",
            "transplantation",
            "w_interpolation",
            "reconstruction",
            "stylemixing",
            "random_generation",
        ],
        required=True,
    )
    parser.add_argument("--ckpt", metavar="CHECKPOINT", required=True)
    parser.add_argument("--test_lmdb", type=str)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--save_image_dir", type=str, default="expr")

    # Below argument is needed for local editing.
    parser.add_argument(
        "--local_editing_part",
        type=str,
        default=None,
        choices=[
            "nose",
            "hair",
            "background",
            "eye",
            "eyebrow",
            "lip",
            "neck",
            "cloth",
            "skin",
            "ear",
        ],
    )

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)
    train_args = ckpt["train_args"]

args中的变量
在这里插入图片描述
train_args中的变量
在这里插入图片描述

  1. train_args命名空间中的变量添加到args命名空间中的变量中。
...
    for key in vars(train_args):
        if not (key in vars(args)):
            setattr(args, key, getattr(train_args, key))
    print(args)

输出打印args命名空间中变量

# 输出打印变量
Namespace(batch=1, batch_per_gpu=8, channel_multiplier=2, ckpt='expr/checkpoints/celeba_hq_256_8x8.pt', d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, local_editing_part='cloth', lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', mixing_type='local_editing', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=2, r1=10, remove_indomain=False, remove_w_rec=False, save_image_dir='expr', size=256, small_generator=False, start_iter=0, test_lmdb='data/celeba_hq/LMDB_test', train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')

在这里插入图片描述

  1. 调用Model()
    ...
    dataset_name = args.dataset
    args.save_image_dir = os.path.join(
        args.save_image_dir, args.mixing_type, dataset_name
    )

    model = Model().to(device)
  1. 进入Model类中的初始化函数中,调用training.model.py中的Generator()类中的初始化函数。
# generate.py
class Model(nn.Module):
    def __init__(self, device="cuda"):
        super(Model, self).__init__()
        self.g_ema = Generator(
            args.size,
            args.mapping_layer_num,
            args.latent_channel_size,
            args.latent_spatial_size,
            lr_mul=args.lr_mul,
            channel_multiplier=args.channel_multiplier,
            normalize_mode=args.normalize_mode,
            small_generator=args.small_generator,
        )
  1. 进入training.model.py中的Generator()类中的初始化函数,调用training.model.py中的PixelNorm()类中的初始化函数。
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        super().__init__()

        self.latent_spatial_size = latent_spatial_size
        self.style_dim = style_dim

        layers = [PixelNorm()]

输入参数
在这里插入图片描述

  1. 进入training.model.py中的PixelNorm()类中的初始化函数。
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
  1. 返回training.model.py中的Generator()类中的初始化函数,调用training.model.py中的EqualLinear类的初始化函数。
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        ...
        for i in range(mapping_layer_num):
        if i != (mapping_layer_num - 1):
            in_channel = style_dim
            out_channel = style_dim
        else:
            in_channel = style_dim
            out_channel = style_dim * latent_spatial_size * latent_spatial_size

        layers.append(
            EqualLinear(
                in_channel, out_channel, lr_mul=lr_mul, activation="fused_lrelu"
            )
        )
  1. 进入training.model.py中的EqualLinear类的初始化函数。
class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

这里是引用
torch.nn.Parameter是继承自torch.Tensor的子类,其主要作用是作为nn.Module中的可训练参数使用。它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去;而module中非nn.Parameter()的普通tensor是不在parameter中的。
注意到,nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torth.Tensor对象的默认值相反。
在nn.Module类中,pytorch也是使用nn.Parameter来对每一个module的参数进行初始化的。

  1. 返回training.model.py中的Generator()类中的初始化函数,完成layers创建,调用training.model.py中的Decoder类中的初始化函数。
    在这里插入图片描述
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        ...
        self.mapping_z = nn.Sequential(*layers)

        self.decoder = Decoder(
            size,
            style_dim,
            latent_spatial_size,
            channel_multiplier=channel_multiplier,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
            lr_mul=1,
            small_generator=small_generator,
        )  # always 1, always zero padding
  1. 进入training.model.py中的Decoder类中的初始化函数,调用training.model.py中的ConstantInput类中的初始化函数。
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        super().__init__()

        self.size = size

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.log_size = int(math.log(size, 2))

        self.input = ConstantInput(
            channels[latent_spatial_size], size=latent_spatial_size
        )

输入参数变量
在这里插入图片描述

  1. 进入training.model.py中的ConstantInput类中的初始化函数。
class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, batch):
        out = self.input.repeat(batch, 1, 1, 1)

        return out
  1. 返回training.model.py中的Decoder类中的初始化函数,调用training.model.py中的StyledConv类的初始化函数。
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
                if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = channels[latent_spatial_size]

        self.conv1 = StyledConv(
            channels[latent_spatial_size],
            channels[latent_spatial_size],
            3,
            stylecode_dim,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )
  1. 进入training.model.py中的StyledConv类的初始化函数,调用training.model.py中的ModulatedConv2d类的初始化函数。
class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        blur_kernel,
        normalize_mode,
        upsample=False,
        activate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )
  1. 进入training.model.py中的ModulatedConv2d类的初始化函数,调用training.model.py中的EqualConv2d类的初始化函数。
class ModulatedConv2d(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        normalize_mode,
        blur_kernel,
        upsample=False,
        downsample=False,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.normalize_mode = normalize_mode
        if normalize_mode == "InstanceNorm2d":
            self.norm = nn.InstanceNorm2d(in_channel, affine=False)
        elif normalize_mode == "BatchNorm2d":
            self.norm = nn.BatchNorm2d(in_channel, affine=False)

        self.beta = None

        self.gamma = EqualConv2d(
            style_dim,
            in_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
            bias_init=1,
        )
        self.beta = EqualConv2d(
            style_dim,
            in_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
            bias_init=0,
        )

  1. 进入training.model.py中的EqualConv2d类的初始化函数。
class EqualConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        lr_mul=1,
        bias=True,
        bias_init=0,
        conv_transpose2d=False,
        activation=False,
    ):
        super().__init__()

        self.out_channel = out_channel
        self.kernel_size = kernel_size

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size).div_(lr_mul)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) * lr_mul

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))

            self.lr_mul = lr_mul
        else:
            self.lr_mul = None

        self.conv_transpose2d = conv_transpose2d

        if activation:
            self.activation = ScaledLeakyReLU(0.2)
            # self.activation = FusedLeakyReLU(out_channel)
        else:
            self.activation = False

输入参数变量
在这里插入图片描述

  1. 返回training.model.py中的ModulatedConv2d类的初始化函数,调用training.model.py中的EqualConv2d类的初始化函数。执行完后返回ModulatedConv2d类的初始化函数,执行完后返回StyledConv类的初始化函数, 调用training.model.op.fused_act.py中的FusedLeakyReLU类的初始化函数。
class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        blur_kernel,
        normalize_mode,
        upsample=False,
        activate=True,
    ):
        ...
        if activate:
            self.activate = FusedLeakyReLU(out_channel)
        else:
            self.activate = None
  1. 进入training.model.op.fused_act.py中的FusedLeakyReLU类的初始化函数。
class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale
  1. 返回StyledConv类的初始化函数,执行完最后一句,返回training.model.py中的Decoder类中的初始化函数,调用training.model.py中的ConvLayer类中的初始化函数。
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        in_channel = channels[latent_spatial_size]

        self.start_index = int(math.log(latent_spatial_size, 2)) + 1  # if 4x4 -> 3
        self.convs = nn.ModuleList()
        self.convs_latent = nn.ModuleList()

        self.convs_latent.append(
            ConvLayer(
                style_dim, stylecode_dim, 3, bias=True, activate=True, lr_mul=lr_mul
            )
        self.convs_latent.append(
            ConvLayer(
                stylecode_dim, stylecode_dim, 3, bias=True, activate=True, lr_mul=lr_mul
            )
        )
  1. 进入training.model.py中的ConvLayer类中的初始化函数。
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        lr_mul=1,
    ):
        assert not (upsample and downsample)
        layers = []

        if upsample:
            stride = 2
            self.padding = 0
            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                    conv_transpose2d=True,
                    lr_mul=lr_mul,
                )
            )

            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))

        else:

            if downsample:
                factor = 2
                p = (len(blur_kernel) - factor) + (kernel_size - 1)
                pad0 = (p + 1) // 2
                pad1 = p // 2

                layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

                stride = 2
                self.padding = 0

            else:
                stride = 1
                self.padding = kernel_size // 2

            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                )
            )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)

输入参数
在这里插入图片描述

  1. 返回training.model.py中的Decoder类中的初始化函数,调用training.model.py中的StyledResBlock类中的初始化函数。
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        for i in range(self.start_index, self.log_size + 1):  # 8x8~ 128x128
            if small_generator:
                stylecode_dim_prev, stylecode_dim_next = style_dim, style_dim
            else:
                stylecode_dim_prev = channels[2 ** (i - 1)]
                stylecode_dim_next = channels[2 ** i]
            self.convs_latent.append(
                ConvLayer(
                    stylecode_dim_prev,
                    stylecode_dim_next,
                    3,
                    upsample=True,
                    bias=True,
                    activate=True,
                    lr_mul=lr_mul,
                )
            )
            self.convs_latent.append(
                ConvLayer(
                    stylecode_dim_next,
                    stylecode_dim_next,
                    3,
                    bias=True,
                    activate=True,
                    lr_mul=lr_mul,
                )
            )

        if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = None

        for i in range(self.start_index, self.log_size + 1):
            out_channel = channels[2 ** i]
            self.convs.append(
                StyledResBlock(
                    in_channel,
                    out_channel,
                    stylecode_dim,
                    blur_kernel,
                    normalize_mode=normalize_mode,
                )
            )

            in_channel = out_channel
  1. 进入training.model.py中的StyledResBlock类中的初始化函数,。
class StyledResBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        style_dim,
        blur_kernel,
        normalize_mode,
        global_feature_channel=None,
    ):
        super().__init__()

        if style_dim is None:
            if global_feature_channel is not None:
                self.conv1 = StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    in_channel + global_feature_channel,
                    blur_kernel=blur_kernel,
                    upsample=True,
                    normalize_mode=normalize_mode,
                )
                self.conv2 = StyledConv(
                    out_channel,
                    out_channel,
                    3,
                    out_channel + global_feature_channel,
                    blur_kernel=blur_kernel,
                    normalize_mode=normalize_mode,
                )
            else:
                self.conv1 = StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    in_channel,
                    blur_kernel=blur_kernel,
                    upsample=True,
                    normalize_mode=normalize_mode,
                )
                self.conv2 = StyledConv(
                    out_channel,
                    out_channel,
                    3,
                    out_channel,
                    blur_kernel=blur_kernel,
                    normalize_mode=normalize_mode,
                )
        else:
            self.conv1 = StyledConv(
                in_channel,
                out_channel,
                3,
                style_dim,
                blur_kernel=blur_kernel,
                upsample=True,
                normalize_mode=normalize_mode,
            )
            self.conv2 = StyledConv(
                out_channel,
                out_channel,
                3,
                style_dim,
                blur_kernel=blur_kernel,
                normalize_mode=normalize_mode,
            )

        self.skip = ConvLayer(
            in_channel, out_channel, 1, upsample=True, activate=False, bias=False
        )

输入参数
在这里插入图片描述

  1. 返回training.model.py中的Decoder类中的初始化函数,执行完后面的代码。
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
                if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = channels[size]

        # add adain to to_rgb
        self.to_rgb = StyledConv(
            channels[size],
            3,
            1,
            stylecode_dim,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )

        self.num_stylecodes = self.log_size * 2 - 2 * (
            self.start_index - 2
        )  # the number of AdaIN layer(stylecodes)
        assert len(self.convs) * 2 + 2 == self.num_stylecodes

        self.latent_spatial_size = latent_spatial_size
  1. 返回training.model.py中的Generator()类中的初始化函数,返回Model类中的初始化函数中,调用training.model.py中的Encoder类中的初始化函数。
class Model(nn.Module):
    def __init__(self, device="cuda"):
        ....
            self.e_ema = Encoder(
            args.size,
            args.latent_channel_size,
            args.latent_spatial_size,
            channel_multiplier=args.channel_multiplier,
        )
  1. 进入training.model.py中的Encoder类中的初始化函数,调用training.model.py中的ResBlock类初始化函数。
class Encoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        channels = {
            1: 512,
            2: 512,
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.from_rgb = ConvLayer(3, channels[size], 1)
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2))
        self.log_size = log_size

        in_channel = channels[size]
        end = int(math.log(latent_spatial_size, 2))

        for i in range(self.log_size, end, -1):
            out_channel = channels[2 ** (i - 1)]

            self.convs.append(
                ResBlock(in_channel, out_channel, blur_kernel, return_features=True)
            )

输入参数
在这里插入图片描述

  1. 进入training.model.py中的ResBlock类初始化函数。
class ResBlock(nn.Module):
    def __init__(
        self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], return_features=False
    ):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )
        self.return_features = return_features
  1. 返回training.model.py中的Encoder类中的初始化函数,执行后面的代
    码。
class Encoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel=[1, 3, 3, 1],
    ):
        ...
            in_channel = out_channel

        self.final_conv = ConvLayer(in_channel, style_dim, 3)
  1. 返回Model类中的初始化函数中,返回generate.py中的主程序中。继续执行,调用training.dataset.py中的GTMaskDataset("data/celeba_hq", transform, args.size)类初始化函数。
    model.g_ema.load_state_dict(ckpt["g_ema"])
    model.e_ema.load_state_dict(ckpt["e_ema"])
    model.eval()

    batch = args.batch

    device = "cuda"
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    if args.mixing_type == "random_generation":
        os.makedirs(args.save_image_dir, exist_ok=True)
    elif args.mixing_type in [
        "w_interpolation",
        "reconstruction",
        "transplantation",
        "stylemixing",
    ]:
        os.makedirs(args.save_image_dir, exist_ok=True)
        dataset = MultiResolutionDataset(args.test_lmdb, transform, args.size)
    elif args.mixing_type == "local_editing":

        if dataset_name == "afhq":
            args.save_image_dir = os.path.join(args.save_image_dir)
            for kind in [
                "mask",
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)
        else:  # celeba_hq
            args.save_image_dir = os.path.join(
                args.save_image_dir,
                args.local_editing_part,
            )
            for kind in [
                "mask",
                "mask_ref",
                "mask_src",
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)
            mask_path_base = f"data/{dataset_name}/local_editing"

        # GT celeba_hq mask images
        if dataset_name == "celeba_hq":
            assert "celeba_hq" in args.test_lmdb

            dataset = GTMaskDataset("data/celeba_hq", transform, args.size)
  1. 进入training.dataset.py中的GTMaskDataset("data/celeba_hq", transform, args.size)类初始化函数。
class GTMaskDataset(Dataset):
    def __init__(self, dataset_folder, transform, resolution=256):

        self.env = lmdb.open(
            f"{dataset_folder}/LMDB_test",
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test")

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

        # convert filename to celeba_hq index
        CelebA_HQ_to_CelebA = (
            f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt"
        )
        CelebA_to_CelebA_HQ_dict = {}

        original_test_path = f"{dataset_folder}/raw_images/test/images"
        mask_label_path = f"{dataset_folder}/local_editing/GT_labels"

        with open(CelebA_HQ_to_CelebA, "r") as fp:
            read_line = fp.readline()
            attrs = re.sub(" +", " ", read_line).strip().split(" ")
            while True:
                read_line = fp.readline()

                if not read_line:
                    break

                idx, orig_idx, orig_file = (
                    re.sub(" +", " ", read_line).strip().split(" ")
                )

                CelebA_to_CelebA_HQ_dict[orig_file] = idx

        self.mask = []

        for filename in os.listdir(original_test_path):
            CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename]
            CelebA_HQ_filename = CelebA_HQ_filename + ".png"
            self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename))

在这里插入图片描述

  1. 返回generate.py中的主程序中,继续执行,调用data_sampler(dataset, shuffle=False)函数。
            parts_index = {
                "background": [0],
                "skin": [1],
                "eyebrow": [6, 7],
                "eye": [3, 4, 5],
                "ear": [8, 9, 15],
                "nose": [2],
                "lip": [10, 11, 12],
                "neck": [16, 17],
                "cloth": [18],
                "hair": [13, 14],
            }

        # afhq, coarse(half-and-half) masks
        else:
            assert "afhq" in args.test_lmdb and "afhq" == dataset_name
            dataset = MultiResolutionDataset(args.test_lmdb, transform, args.size)

    if args.mixing_type in [
        "w_interpolation",
        "reconstruction",
        "stylemixing",
        "local_editing",
    ]:
        n_sample = len(dataset)
        sampler = data_sampler(dataset, shuffle=False)

这里是引用

  1. 进入data_sample(dataset, shuffle)函数
def data_sampler(dataset, shuffle):
    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)
  1. 返回主程序。
        loader = data.DataLoader(
            dataset,
            batch,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        # generated images should match with n sample
        if n_sample % batch == 0:
            assert len(loader) == n_sample // batch
        else:
            assert len(loader) == n_sample // batch + 1

        total_latents = torch.Tensor().to(device)
        real_imgs = torch.Tensor().to(device)

        if args.mixing_type == "local_editing":
            if dataset_name == "afhq":
                masks = (
                    -2 * torch.ones(n_sample, args.size, args.size).to(device).float()
                )

                mix_type = list(range(n_sample))
                random.shuffle(mix_type)
                horizontal_mix = mix_type[: n_sample // 2]
                vertical_mix = mix_type[n_sample // 2 :]

                masks[horizontal_mix, :, args.size // 2 :] = 2
                masks[vertical_mix, args.size // 2 :, :] = 2
            else:
                masks = torch.Tensor().to(device).long()

    with torch.no_grad():
        if args.mixing_type == "random_generation":
            truncation = 0.7
            truncation_sample = 5000
            truncation_mean_latent = torch.Tensor().to(device)
            for _ in range(truncation_sample // batch):
                z = make_noise(batch, args.latent_channel_size, device)
                partial_mean_latent = model(z, mode="calculate_mean_stylemap")
                truncation_mean_latent = torch.cat(
                    [truncation_mean_latent, partial_mean_latent], dim=0
                )
            truncation_mean_latent = truncation_mean_latent.mean(0, keepdim=True)

            # refer to stylegan official repository: https://github.com/NVlabs/stylegan/blob/master/generate_figures.py
            cx, cy, cw, ch, rows, lods = 0, 0, 1024, 1024, 3, [0, 1, 2, 2, 3, 3]

            for seed in range(0, 4):
                torch.manual_seed(seed)
                png = f"{args.save_image_dir}/random_generation_{seed}.png"
                print(png)

                total_images_len = sum(rows * 2 ** lod for lod in lods)
                total_images = torch.Tensor()

                while total_images_len > 0:
                    num = batch if total_images_len > batch else total_images_len
                    z = make_noise(num, args.latent_channel_size, device)
                    total_images_len -= batch

                    images = model(
                        (z, truncation, truncation_mean_latent),
                        mode="random_generation",
                    )

                    images = images.permute(0, 2, 3, 1)
                    images = images.cpu()
                    total_images = torch.cat([total_images, images], dim=0)

                total_images = torch.clamp(total_images, min=-1.0, max=1.0)
                total_images = (total_images + 1) / 2 * 255
                total_images = total_images.numpy().astype(np.uint8)

                canvas = Image.new(
                    "RGB",
                    (sum(cw // 2 ** lod for lod in lods), ch * rows),
                    "white",
                )
                image_iter = iter(list(total_images))
                for col, lod in enumerate(lods):
                    for row in range(rows * 2 ** lod):
                        image = Image.fromarray(next(image_iter), "RGB")
                        # image = image.crop((cx, cy, cx + cw, cy + ch))
                        image = image.resize(
                            (cw // 2 ** lod, ch // 2 ** lod), Image.ANTIALIAS
                        )
                        canvas.paste(
                            image,
                            (
                                sum(cw // 2 ** lod for lod in lods[:col]),
                                row * ch // 2 ** lod,
                            ),
                        )
                canvas.save(png)

        elif args.mixing_type == "reconstruction":
            for i, real_img in enumerate(tqdm(loader, mininterval=1)):
                real_img = real_img.to(device)
                recon_image = model(real_img, "reconstruction")

                for i_b, (img_1, img_2) in enumerate(zip(real_img, recon_image)):
                    save_images(
                        [img_1, img_2],
                        [
                            f"{args.save_image_dir}/{i*batch+i_b}_real.png",
                            f"{args.save_image_dir}/{i*batch+i_b}_recon.png",
                        ],
                    )

        elif args.mixing_type == "transplantation":

            for kind in [
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)

            # AFHQ
            transplantation_dataset = [
                (62, 271, [((4, 2), (3, 2), 2, 4), ((0, 1), (0, 1), 3, 2)])
            ]

            for index_src, index_ref, coordinates in transplantation_dataset:
                src_img = dataset[index_src].to(device)
                ref_img = dataset[index_ref].to(device)

                mixed_image, recon_img_src, recon_img_ref = model(
                    (src_img, ref_img, coordinates), mode="transplantation"
                )

                ratio = 256 // 8

                src_img = (src_img + 1) / 2
                ref_img = (ref_img + 1) / 2

                colors = [(0, 0, 255), (0, 255, 0), (0, 255, 0)]

                for color_i, (
                    (src_p_y, src_p_x),
                    (ref_p_y, ref_p_x),
                    height,
                    width,
                ) in enumerate(coordinates):
                    for i in range(2):
                        img = src_img if i == 0 else ref_img
                        img = img.cpu()
                        img = transforms.ToPILImage()(img)
                        img = np.asarray(img)
                        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                        if i == 0:
                            img = cv2.rectangle(
                                img,
                                (src_p_x * ratio, src_p_y * ratio),
                                (
                                    (src_p_x + width) * ratio,
                                    (src_p_y + height) * ratio,
                                ),
                                colors[color_i],
                                2,
                            )
                        else:
                            img = cv2.rectangle(
                                img,
                                (ref_p_x * ratio, ref_p_y * ratio),
                                (
                                    (ref_p_x + width) * ratio,
                                    (ref_p_y + height) * ratio,
                                ),
                                colors[color_i],
                                2,
                            )
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                        img = transforms.ToTensor()(img)

                        if i == 0:
                            src_img = img
                        else:
                            ref_img = img

                save_images(
                    [mixed_image[0], recon_img_src[0], recon_img_ref[0]],
                    [
                        f"{args.save_image_dir}/synthesized_image/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/source_reconstruction/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/reference_reconstruction/{index_src}_{index_ref}.png",
                    ],
                )

                save_images(
                    [src_img, ref_img],
                    [
                        f"{args.save_image_dir}/source_image/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/reference_image/{index_src}_{index_ref}.png",
                    ],
                    range=(0, 1),
                )

        else:
            for i, real_img in enumerate(tqdm(loader, mininterval=1)):

在这里插入图片描述

  1. 进入training.dataset.py中的GTMaskDataset类中的__getitem__函数中。
class GTMaskDataset(Dataset):
    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        mask = Image.open(self.mask[index])

        mask = mask.resize((self.resolution, self.resolution), Image.NEAREST)
        mask = transforms.ToTensor()(mask)

        mask = mask.squeeze()
        mask *= 255
        mask = mask.long()

        assert mask.shape == (self.resolution, self.resolution)
        return img, mask
  1. 返回主程序调用Model类的model函数.
                if (args.mixing_type == "local_editing") and (
                    dataset_name == "celeba_hq"
                ):
                    real_img, mask = real_img
                    mask = mask.to(device)
                    masks = torch.cat([masks, mask], dim=0)
                real_img = real_img.to(device)

                latents = model(real_img, "projection")

  1. 进入Model类的forward函数.
class Model(nn.Module):
    def forward(self, input, mode):
        if mode == "projection":
            fake_stylecode = self.e_ema(input)

            return fake_stylecode
  1. 进入Encoder类的forward函数.
class Encoder(nn.Module):
    def forward(self, input):
        out = self.from_rgb(input)
  1. 进入EqualConv2d类的forward函数.
class EqualConv2d(nn.Module):
    def forward(self, input):
        if self.lr_mul != None:
            bias = self.bias * self.lr_mul
        else:
            bias = None

        if self.conv_transpose2d:
            # group version for fast training
            batch, in_channel, height, width = input.shape
            input_temp = input.view(1, batch * in_channel, height, width)
            weight = self.weight.unsqueeze(0).repeat(batch, 1, 1, 1, 1)
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(
                input_temp,
                weight * self.scale,
                bias=bias,
                padding=self.padding,
                stride=2,
                groups=batch,
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            out = F.conv2d(
                input,
                self.weight * self.scale,
                bias=bias,
                stride=self.stride,
                padding=self.padding,
            )

        if self.activation:
            out = self.activation(out)

        return out
  1. 返回Encoder类的forward函数.
class Encoder(nn.Module):
    def forward(self, input):
        ...
        for convs in self.convs:
            out, _, _ = convs(out)

        out = self.final_conv(out)

        return out  # spatial style code
  1. 进入EqualConv2d类的forward函数,然后进入ResBlock类的forward函数.
class ResBlock(nn.Module):
    def forward(self, input):
        out1 = self.conv1(input)
ModuleList(
  (0): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(128, 128, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(128, 256, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(128, 256, 1, stride=2, padding=0)
    )
  )
  (1): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(256, 256, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(256, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(256, 512, 1, stride=2, padding=0)
    )
  )
  (2): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
  (3): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
  (4): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
)
  1. 进入EqualConv2d类的forward函数.

  2. 返回ResBlock类的forward函数。

class ResBlock(nn.Module):
    def forward(self, input):
        ...
        out2 = self.conv2(out1)
  1. 进入Blur类的forward函数,调用upfirdn2d(input, self.kernel, pad=self.pad)函数。
class Blur(nn.Module):
    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out
  1. 进入training.op.upfirdn2d.py中的upfirdn2d(input, self.kernel, pad=self.pad)函数。
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = UpFirDn2d.apply(
        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
    )

    return out
  1. 进入UpFirDn2d类的forward函数中。
class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out
  1. 返回Blur类的forward函数。进入EqualConv2d类的forward函数.

  2. 返回ResBlock类的forward函数。

class ResBlock(nn.Module):
    def forward(self, input):
        ...
        skip = self.skip(input)
        out = (out2 + skip) / math.sqrt(2)

        if self.return_features:
            return out, out1, out2
        else:
            return out
  1. 返回Encoder类的forward函数.
class Encoder(nn.Module):
    def forward(self, input):
         ...
        for convs in self.convs:
            out, _, _ = convs(out)

        out = self.final_conv(out)

        return out  # spatial style code
  1. 返回Model类的forward函数.返回主程序。继续执行。
                total_latents = torch.cat([total_latents, latents], dim=0)
                real_imgs = torch.cat([real_imgs, real_img], dim=0)
                            elif args.mixing_type == "local_editing":
                if dataset_name == "afhq":
                    # change it later
                    indices = list(range(len(total_latents)))
                    random.shuffle(indices)
                    indices1 = indices[: len(total_latents) // 2]
                    indices2 = indices[len(total_latents) // 2 :]

                else:
                    with open(
                        f"{mask_path_base}/celeba_hq_test_GT_sorted_pair.pkl",
                        "rb",
                    ) as f:
                        sorted_similarity = pickle.load(f)

                    indices1 = []
                    indices2 = []
                    for (i1, i2), _ in sorted_similarity[args.local_editing_part]:
                        indices1.append(i1)
                        indices2.append(i2)

在这里插入图片描述
在这里插入图片描述
51. 继续执行主程序,调用模型model((total_latents[index1], total_latents[index2], mask),local_editing",)

            for loop_i, (index1, index2) in tqdm(
                enumerate(zip(indices1, indices2)), total=n_sample
            ):
                if args.mixing_type == "w_interpolation":
                    imgs = model(
                        (total_latents[index1], total_latents[index2]),
                        "w_interpolation",
                    )
                    assert len(imgs) == 1
                    save_image(
                        imgs[0],
                        f"{args.save_image_dir}/{loop_i}.png",
                    )
                elif args.mixing_type == "stylemixing":
                    n_rows = len(index2)
                    coarse_img, fine_img = model(
                        (
                            torch.stack([total_latents[index1] for _ in range(n_rows)]),
                            torch.stack([total_latents[i2] for i2 in index2]),
                        ),
                        "stylemixing",
                    )

                    save_images(
                        [coarse_img, fine_img],
                        [
                            f"{args.save_image_dir}/{index1}_coarse.png",
                            f"{args.save_image_dir}/{index1}_fine.png",
                        ],
                    )

                elif args.mixing_type == "local_editing":
                    src_img = real_imgs[index1]
                    ref_img = real_imgs[index2]

                    if dataset_name == "celeba_hq":
                        mask1_logit = masks[index1]
                        mask2_logit = masks[index2]

                        mask1 = -torch.ones(mask1_logit.shape).to(
                            device
                        )  # initialize with -1
                        mask2 = -torch.ones(mask2_logit.shape).to(
                            device
                        )  # initialize with -1

                        for label_i in parts_index[args.local_editing_part]:
                            mask1[(mask1_logit == label_i) == True] = 1
                            mask2[(mask2_logit == label_i) == True] = 1

                        mask = mask1 + mask2
                        mask = mask.float()
                    elif dataset_name == "afhq":
                        mask = masks[index1]

                    mixed_image, recon_img_src, recon_img_ref = model(
                        (total_latents[index1], total_latents[index2], mask),
                        "local_editing",
                    )
  1. 进入Model类的forward函数,调用Generator类的forward函数。
class Model(nn.Module):
    def forward(self, input, mode):
        ...
        elif mode == "local_editing":
            w1, w2, mask = input
            w1, w2, mask = w1.unsqueeze(0), w2.unsqueeze(0), mask.unsqueeze(0)

            if dataset_name == "celeba_hq":
                mixed_image = self.g_ema(
                    [w1, w2],
                    input_is_stylecode=True,
                    mix_space="w_plus",
                    mask=mask,
                )[0]
  1. 进入Generator类的forward函数,调用decoder(stylecode, mix_space=mix_space, mask=mask)函数。
class Generator(nn.Module):
    def forward(
        self,
        input,
        return_stylecode=False,
        input_is_stylecode=False,
        mix_space=None,
        mask=None,
        calculate_mean_stylemap=False,
        truncation=None,
        truncation_mean_latent=None,
    ):
        if calculate_mean_stylemap:  # calculate mean_latent
            stylecode = self.mapping_z(input)
            return stylecode.mean(0, keepdim=True)
        else:
            if input_is_stylecode:
                stylecode = input
            else:
                stylecode = self.mapping_z(input)
                if truncation != None and truncation_mean_latent != None:
                    stylecode = truncation_mean_latent + truncation * (
                        stylecode - truncation_mean_latent
                    )
                N, C = stylecode.shape
                stylecode = stylecode.reshape(
                    N, -1, self.latent_spatial_size, self.latent_spatial_size
                )

            image = self.decoder(stylecode, mix_space=mix_space, mask=mask)

            if return_stylecode == True:
                return image, stylecode
            else:
                return image, None

在这里插入图片描述

  1. 进入Decoder类的forward函数中。
class Decoder(nn.Module):
   def forward(self, style_code, mix_space=None, mask=None):
       ...
       else:
           batch = style_code[0].shape[0]

       style_codes = []
       ...
       elif mix_space == "w_plus":  # mix stylemaps in W+ space
	       style_code1 = style_code[0]
	       style_code2 = style_code[1]
	       style_codes1 = []
	       style_codes2 = []
	
	       for up_layer in self.convs_latent:
	           style_code1 = up_layer(style_code1)
	           style_code2 = up_layer(style_code2)
	           style_codes1.append(style_code1)
	           style_codes2.append(style_code2)
	
	       for i in range(0, len(style_codes2)):
	           _, C, H, W = style_codes2[i].shape
	           ratio = self.size // H
	           # print(mask)
	           mask_for_latent = nn.MaxPool2d(kernel_size=ratio, stride=ratio)(mask)
	           mask_for_latent = mask_for_latent.unsqueeze(1).repeat(1, C, 1, 1)
	           style_codes2[i] = torch.where(
	               mask_for_latent > -1, style_codes2[i], style_codes1[i]
	           )
	
	       style_codes = style_codes2
	    ....
        out = self.input(batch)
        out = self.conv1(out, style_codes[0])

        for i in range(len(self.convs)):
            out = self.convs[i](out, [style_codes[2 * i + 1], style_codes[2 * i + 2]])
        image = self.to_rgb(out, style_codes[-1])

        return image

self.cons_latent结构

ModuleList(
  (0): ConvLayer(
    (0): EqualConv2d(64, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (1): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (2): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (3): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (4): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (5): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (6): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (7): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (8): ConvLayer(
    (0): EqualConv2d(512, 256, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (9): ConvLayer(
    (0): EqualConv2d(256, 256, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (10): ConvLayer(
    (0): EqualConv2d(256, 128, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (11): ConvLayer(
    (0): EqualConv2d(128, 128, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
)
  1. 返回Generator类的forward函数,返回Model类的forward函数,继续执行后面的函数。
class Model(nn.Module):
    def forward(self, input, mode):
        ...
        
            recon_img_src, _ = self.g_ema(w1, input_is_stylecode=True)
            recon_img_ref, _ = self.g_ema(w2, input_is_stylecode=True)

            return mixed_image, recon_img_src, recon_img_ref
  1. 返回主程序
                    mixed_image, recon_img_src, recon_img_ref = model(
                        (total_latents[index1], total_latents[index2], mask),
                        "local_editing",
                    )

                    save_images(
                        [
                            mixed_image[0],
                            recon_img_src[0],
                            src_img,
                            ref_img,
                            recon_img_ref[0],
                        ],
                        [
                            f"{args.save_image_dir}/synthesized_image/{index1}.png",
                            f"{args.save_image_dir}/source_reconstruction/{index1}.png",
                            f"{args.save_image_dir}/source_image/{index1}.png",
                            f"{args.save_image_dir}/reference_image/{index1}.png",
                            f"{args.save_image_dir}/reference_reconstruction/{index1}.png",
                        ],
                    )

                    mask[mask < -1] = -1
                    mask[mask > -1] = 1

                    save_image(
                        mask,
                        f"{args.save_image_dir}/mask/{index1}.png",
                    )

                    if dataset_name == "celeba_hq":
                        save_images(
                            [mask1, mask2],
                            [
                                f"{args.save_image_dir}/mask_src/{index1}.png",
                                f"{args.save_image_dir}/mask_ref/{index1}.png",
                            ],
                        )

pair_masks.py代码调试

以celeba_hq的人脸分析成分,保存对应的输入图像的交并比。

  1. 主程序执行,进入group_pair_GT()函数。
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument(
        "--save_dir", type=str, default="../data/celeba_hq/local_editing"
    )

    args = parser.parse_args()
    args.dataset_name = "celeba_hq"
    os.makedirs(args.save_dir, exist_ok=True)
    args.path = f"../data/{args.dataset_name}"
    args.mask_origin = 'GT_test'
    with torch.no_grad():
        # our CelebA-HQ test dataset contains 500 images
        # change this value if you have the different number of GT_labels
        args.n_sample = 5
        group_pair_GT()
  1. 进入group_pair_GT()函数,调用GTMaskDataset(args.path, transform, images_size)函数.
@torch.no_grad()
def group_pair_GT():
    device = "cuda"
    args.n_sample = 5

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    images_size = 256  # you can use other resolution for calculating if your LMDB(args.path) has different resolution.
    dataset = GTMaskDataset(args.path, transform, images_size)

  1. 进入GTMaskDataset(args.path, transform, images_size)函数.
# dataset.py
class GTMaskDataset(Dataset):
    def __init__(self, dataset_folder, transform, resolution=256):

        self.env = lmdb.open(
            f"{dataset_folder}/LMDB_test",
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test")

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

        # convert filename to celeba_hq index
        CelebA_HQ_to_CelebA = (
            f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt"
        )
        CelebA_to_CelebA_HQ_dict = {}

        original_test_path = f"{dataset_folder}/raw_images/test/images"
        mask_label_path = f"{dataset_folder}/local_editing/GT_labels"

        with open(CelebA_HQ_to_CelebA, "r") as fp:
            read_line = fp.readline()
            attrs = re.sub(" +", " ", read_line).strip().split(" ")
            while True:
                read_line = fp.readline()

                if not read_line:
                    break

                idx, orig_idx, orig_file = (
                    re.sub(" +", " ", read_line).strip().split(" ")
                )

                CelebA_to_CelebA_HQ_dict[orig_file] = idx
        self.mask = []

        for filename in os.listdir(original_test_path):
            CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename]
            CelebA_HQ_filename = CelebA_HQ_filename + ".png"
            self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename))

在这里插入图片描述
在这里插入图片描述

  1. 返回group_pair_GT()函数,继续执行。
@torch.no_grad()
def group_pair_GT():
    ...
    parts_index = {
        "all": None,
        "background": [0],
        "skin": [1],
        "eyebrow": [6, 7],
        "eye": [3, 4, 5],
        "ear": [8, 9, 15],
        "nose": [2],
        "lip": [10, 11, 12],
        "neck": [16, 17],
        "cloth": [18],
        "hair": [13, 14],
    }

    indexes = range(args.n_sample)

    similarity_dict = {}
    parts = parts_index.keys()

    for part in parts:
        similarity_dict[part] = {}

    for src, ref in tqdm(
        itertools.combinations(indexes, 2),
        total=sum(1 for _ in itertools.combinations(indexes, 2)),
    ):
        _, mask1 = dataset[src]
        _, mask2 = dataset[ref]
        mask1 = mask1.to(device)
        mask2 = mask2.to(device)
        for part in parts:
            if part == "all":
                similarity = torch.sum(mask1 == mask2).item() / (images_size ** 2)
                similarity_dict["all"][src, ref] = similarity
            else:
                part1 = torch.zeros(
                    [images_size, images_size], dtype=torch.bool, device=device
                )
                part2 = torch.zeros(
                    [images_size, images_size], dtype=torch.bool, device=device
                )

                for p in parts_index[part]:
                    part1 = part1 | (mask1 == p)
                    part2 = part2 | (mask2 == p)

                intersection = (part1 & part2).sum().float().item()
                union = (part1 | part2).sum().float().item()
                if union == 0:
                    similarity_dict[part][src, ref] = 0.0
                else:
                    sim = intersection / union
                    similarity_dict[part][src, ref] = sim

    sorted_similarity = {}

    for part, similarities in similarity_dict.items():
        all_indexes = set(range(args.n_sample))
        sorted_similarity[part] = []

        sorted_list = sorted(similarities.items(), key=(lambda x: x[1]), reverse=True)

        for (i1, i2), prob in sorted_list:
            if (i1 in all_indexes) and (i2 in all_indexes):
                all_indexes -= {i1, i2}
                sorted_similarity[part].append(((i1, i2), prob))
            elif len(all_indexes) == 0:
                break

        assert len(sorted_similarity[part]) == args.n_sample // 2

    with open(
        f"{args.save_dir}/{args.dataset_name}_test_{args.mask_origin}_sorted_pair.pkl",
        "wb",
    ) as handle:
        pickle.dump(sorted_similarity, handle)

这里是引用

在这里插入图片描述

prepare_data.py调试

以lmdb的方式保存归一化后的图像和图像数量。

  1. 主程序执行,调用prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)函数。
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=str)
    parser.add_argument("--size", type=str, default="128,256,512,1024")
    parser.add_argument("--n_worker", type=int, default=1)
    parser.add_argument("--resample", type=str, default="bilinear")
    parser.add_argument("path", type=str)

    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]
    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)
    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)

ImageFolder是torchvision的函数,在读取时路径的设置需要注意,例如图像的路径为/data/test/raw_test/images/目录下的xxx1.jpg … xxxn.jpg,在传入ImageFolder函数时路径应为/data/test/raw_test,而不是/data/test/raw_test/images。

  1. 进入prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS)函数。调用resize_worker函数。
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
  1. 进入resize_worker函数,调用resize_multiple(img, sizes=sizes, resample=resample)函数
def resize_worker(img_file, sizes, resample):
    i, file = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out
  1. 进入resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100)函数, 调用resize_and_convert(img, size, resample, quality)函数。
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs
  1. 进入resize_and_convert(img, size, resample, quality=100)函数。
def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, (size, size), resample)
    # img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="jpeg", quality=quality)
    val = buffer.getvalue()

    return val
  1. 返回resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100)函数,返回resize_worker函数,返回prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS)函数。
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    ...
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))

参考资料
torch.nn.Parameter理解
详细介绍Python进度条tqdm的使用
Python itertools模块combinations方法

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值