【图像&三维编辑】DragGAN与Drag3D(原理+代码)

# 系列文章目录

例如:第一章 Python 机器学习入门之pandas的使用



前言

合成满足用户需求的视觉内容通常需要对生成对象的姿态、形状、表达和布局的灵活和精确的可控性。DragGAN:“拖动”图像中的任何点,精确修改图像到目标点。它由两个主要组件组成 1)一种基于特征的运动监督,驱动手柄点向目标位置移动;2)一种新的点跟踪方法,利用鉴别生成器特性来保持定位手柄点的位置。DragGAN可以使图像变形,精确控制像素的位置,从而操纵不同类别的姿态、形状、表达和布局,如动物、汽车、人、景观等。由于这些操作是在学习到的GAN生成图像流形上执行的,它们倾向于产生现实的输出,即使是在具有挑战性的场景下,如幻觉遮挡的内容和变形的形状,一致地遵循物体的刚性。

本文介绍了图像的DragGAN 和3D对象的Drag编辑


题目:Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold
论文:https://arxiv.org/pdf/2305.10973.pdf

一、Introduction

GANs [Goodfelletal.2014],在合成随机逼真图像方面取得了前所未有的成功。在现实应用中,这种基于学习的图像合成方法的一个关键功能要求是对合成视觉内容的可控性。理想的可控图像合成方法应具有以下特性 1)灵活性:能够控制不同的空间属性,包括位置、姿势、形状、表达、布局; 2)精度:能够控制高精度的空间属性; 3)通用性:适用于不同的对象类别。虽然以前的工作只满足这些属性中的一到两个,但我们的目标是在这项工作中实现所有这些属性。

DragGAN允许用户单击图像上任意数量的手柄点和目标点,其目标是驱动手柄点到达相应的目标点,允许用户控制不同的空间属性,并且与对象类别无关

DragGAN解决了两个子问题,包括1)监督手柄点向目标移动,2)跟踪手柄点,以便在每个编辑步骤中都知道它们的位置。 DragGAN前提是,GAN的特征空间具有足够的区别,可以使运动监督和精确的点跟踪。DragGAN还允许用户选择性地绘制感兴趣的区域来执行特定于区域的编辑。DragGAN不依赖于任何额外的网络(如RAFT ),在一个RTX 3090 GPU上只需要几秒钟。

DragGAN在不同的数据集进行评估,包括动物(狮子、狗、猫、马和马)、(脸和全身)、汽车和景观。与传统应用变形方法不同,DragGAN是在学习的图像流形上执行的,它倾向于服从底层的对象结构

二、 RELATED WORK

2.1.用于交互式内容创建的生成模型

目前大多数方法使用生成对抗网络(GANs)或扩散模型进行可控图像合成。

无条件的GANs。GANs是将低维随机采样的潜在向量转换为真实图像的生成模型。他们使用对抗性学习进行训练,并可用于生成高分辨率的逼真图像。大多数GAN模型,如StyleGAN,都不能直接实现对生成的图像进行可控编辑。
条件GANs。网络接收一个条件输入,如分割图或3D变量,以及随机采样的潜在向量生成逼真的图像。EditGAN通过首先建模图像和分割图的联合分布,然后计算与编辑的分割地图对应的新图像来进行编辑。

使用无条件GANs的控制性。目前已经提出了几种通过操纵输入潜在向量来编辑无条件gan的方法。一些方法通过从人工注释或先前的三维模型中进行的监督学习来找到有意义的潜在方向。其他方法以无监督的方式计算潜在空间中的重要语义方向。最近,通过引入中间“斑点”或热图来实现粗物体位置的可控性。所有这些方法都可以编辑与图像对齐的语义属性,如外观,或粗糙的几何属性,如对象的位置和姿态。虽然以风格进行编辑展示了一些空间属性编辑能力,但它只能通过在不同样本之间传输局部语义来实现这一点。与这些方法相比,我们的方法允许用户对空间吸引力进行细粒度的控制。

GAN Warping 也使用了基于点的编辑,但是,它们只支持分布外的图像编辑。一些扭曲的图像可以用来更新生成模型,使所有生成的图像都显示出类似的扭曲。然而,这种方法并不能确保扭曲导致真实的图像。此外,它不启用控制,如更改对象的3D姿态等控制。与我们类似,用户控制的lableLT[通过转换GAN的潜在向量来实现基于点的编辑。但是,这种方法只支持使用拖动图像上的单个点进行编辑,并且不能很好地处理多点约束。另外,控制也不精确,即经过编辑后,往往没有达到目标点。

3d感知GAN。几种方法修改GAN的架构以实现3D控制。在这里,模型生成可以使用基于物理的分析渲染器进行渲染的三维表示。然而,与我们的方法不同的是,控制仅限于全局姿态或照明。

扩散模型。已经实现了高质量的图像合成。这些模型迭代地去噪一个随机采样的噪声,以创建一个逼真的图像。最近的模型显示,以文本输入为条件的表达性图像合成。然而,自然语言并不能支持对图像的空间属性进行细粒度的控制,因此,所有的文本条件方法都被限制为高级的语义编辑。此外,目前的扩散模型速度较慢,因为它们需要多个去噪步骤。虽然在有效采样方面取得了进展,但GANs仍然明显更有效。

2.2.点跟踪

为了跟踪视频中的点,一个明显的方法是通过连续帧之间的光流估计。光流估计是估计两幅图像之间运动场的经典问题。传统的方法用手工制作的标准来解决优化问题;基于深度学习的方法有更好的性能,通常使用具有标注的光流的合成数据来训练深度神经网络。其中,目前应用最广泛的方法是RAFT,它通过迭代算法估计光流。最近,Harley等人[2022]将这种迭代算法与传统的“粒子视频”方法结合起来,产生了一种新的点跟踪方法PIPs。pip考虑跨多个帧的信息,因此比以前的方法能更好地处理远程跟踪。

在这项工作中,我们展示了在不使用上述任何方法或其他神经网络的情况下,对GAN生成图像执行点跟踪的可能性。结果表明,GAN的特征空间具有足够的鉴别性,可以简单地通过特征匹配来实现跟踪。虽然之前的一些工作也利用了语义分割中的鉴别特征,我们首先将基于点的编辑问题与区分GAN特征的直觉联系起来,并设计了一种具体的方法。摆脱额外的跟踪模型可以使我们的方法更有效地运行,以支持交互式编辑。尽管我们的方法很简单,但在我们的实验中,我们证明了它优于最先进的点跟踪方法,包括RAFT和pip。

三、DragGAN

我们的研究基于StyleGAN2体系结构:

StyleGAN 的专业术语。在StyleGAN2架构中,一个512维的潜在码 𝒛∈N(0,𝑰)通过一个映射网络被映射到一个中间的潜在码 𝒘∈R512。然后𝒘被发送到生成器 𝐺,以生成输出图像I=𝐺(𝒘)。在这个过程中,𝒘被复制了几次,并发送到生成器𝐺的不同层,以控制不同级别的属性。或者,我们也可以对不同的层使用不同的𝒘,在这种情况下,输入将是 𝒘∈R𝑙×512=W+,其中𝑙是层数。这种约束较少的W+空间更具表现力。由于生成器𝐺学习从低维潜在空间到更高维图像空间的映射,它可以看作是建模一个图像流形。

3.1 交互式的图像编辑

整体pipline 如图所示,对具styleGAN2 中latent code 𝒘的图像I∈R3×𝐻×𝑊做移动。假设原始图像的潜在空间像素点为:
在这里插入图片描述
对应到交互转换后的潜在空间像素点为(即 𝒑𝑖 的对应目标点是 𝒕𝑖):

在这里插入图片描述

在这里插入图片描述
还允许用户选择性地绘制一个二进制掩模M,表示图像的哪个区域是可移动的

每个优化步骤由两个子步骤组成,包括 1)运动监控; 2)点跟踪

1.运动监督:经过一次优化,得到了一个新的潜在代码𝒘‘和一个新的图像I’,导致图像中对象的轻微移动。移动的损失被用来优化潜在代码𝒘

请注意,运动监督步骤只将每个手柄点向目标移动一小的步骤,但步骤的确切长度不清楚,因为它受到复杂的优化动态,因不同的对象和部件而不同。

2.我们然后更新句柄点{𝒑𝑖}的位置,以跟踪对象上相应的点

这个跟踪过程是必要的,因为如果手柄点(例如,狮子的鼻子)没有被准确地跟踪,那么在下一个运动监督步骤中,错误的点(例如,狮子的脸)将被监督,导致不希望的结果。

经过跟踪后,根据新的 pi 和latent code,重复上述优化步骤

这个优化过程一直进行,直到手柄点{𝒑𝑖}到达目标点{𝒕𝑖}的位置,在我们的实验中,这通常需要30-200次迭代。用户还可以在任何中间步骤中停止优化。编辑完成后,用户可以输入新的句柄和目标点,并继续编辑,直到对结果满意为止。

3.2 运动监督

提出了一个不依赖于任何附加的神经网络的。运动监督损失:,Generator 的中间特征具有区别性一个简单的损失就足以监督运动。具体采用StyleGAN2的第6块之后的特征映射F,由于在分辨率和鉴别性之间有很好的权衡,它在所有特征中表现最好。我们通过双线性插值来调整F的大小,使其具有与最终图像相同的分辨率。如图3所示,为了将一个手柄点𝒑𝑖移动到目标点𝒕𝑖,我们的想法是监督𝒑𝑖周围的一个小斑块(红色圆圈),通过一个小的步骤(蓝色圆圈)向𝒕𝑖移动。我们使用Ω1(𝒑𝑖,𝑟1)来表示距离𝒑𝑖小于𝑟1的像素,那么我们的运动监督损失为:
在这里插入图片描述

损失函数
在这里插入图片描述

式中,F(𝒒)表示 F在像素 𝒒处的特征值,𝒅𝑖= 是指向𝒑𝑖到𝒕𝑖的归一化向量(𝒅𝑖=0,如果𝒕𝑖=𝒑𝑖),F0是与初始图像对应的特征映射。注意,第一项是对所有{𝒑𝑖}求和。由于𝒒𝑖+𝒅𝑖的分量不是整数,我们通过双线性插值得到F(𝒒𝑖+𝒅𝑖)。重要的是,当使用这种损失执行反向传播时,梯度不会通过F(𝒒𝑖)进行反向传播。这将激励𝒑𝑖转移到𝒑𝑖+𝒅𝑖,但反之亦然。在给出二值掩模M的情况下,我们保持未掩蔽区域固定,重建损失显示为第二项。在每个运动监控步骤中,该损失用于优化一步的latent code 𝒘

𝒘可以在W空间或在W+空间中进行优化(这取决于用户是否想要一个更受约束的图像流形)。W+空间更容易实现分布外操作(如下图中的cat),本工作使用W+来获得更好的可编辑性。在实践中,我们观察到图像的空间属性主要受前6层的𝒘的影响,而其余的空间属性只影响外观。因此,受风格混合技术的启发[Karras等人,2019年],我们只更新了前6层的𝒘,同时修复了其他层以保持外观。这种选择性的优化导致了所期望的图像内容的轻微移动。

在这里插入图片描述

3.3 点跟踪

之前的运动监督产生了一个新的潜在代码𝒘‘,新的特征映射F’,和一个新的图像I‘。由于运动监督步骤不容易提供移动点的精确新位置,我们在这里的 目标是更新每个𝒑𝑖,使它跟踪对象上的相应点。点跟踪通常通过光流估计模型或粒子视频方法来执行,我们提出了一种新的点跟踪方法:GANs的鉴别特征可以很好地捕获密集的对应关系,从而通过在特征块中的最近邻搜索来有效地进行跟踪。具体来说,我们将初始句柄点的特征表示为𝒇𝑖=F0(𝒑𝑖)。我们把𝒑𝑖周围的补丁表示为Ω2(𝒑𝑖,𝑟2)={(𝑥,𝑦) | |𝑥−𝑥𝑝,𝑖|<𝑟2,|𝑦−𝑦𝑝,𝑖|<𝑟2}.然后通过在Ω2(𝒑𝑖,𝑟2)中搜索𝑓𝑖的最近邻,得到跟踪点。

在这里插入图片描述
通过这种方式,𝒑𝑖将被更新为跟踪该对象。对于多个操作点 pi,对每个点应用相同的过程。注意,这里我们也在考虑StyleGAN2的第6个块之后的特征映射F‘。特征图的分辨率为256×256(可双线性插值到原图大小)。

四、GET3D(英伟达:噪声---->3D物体)

题目:A Generative Model of High Quality 3D Textured Shapes Learned from Images (NeurIPS 2022 )
代码:https://github.com/nv-tlabs/GET3D

三维mesh生成效果:
在这里插入图片描述
GET3D 包括两个分支:

1.几何分支:可微的输出任意拓扑的表面mesh

2.纹理分支:根据查询的表面点来产生 texture field,还可以扩展到表面的其他属性,比如材质

训练过程中,一个有效的可微栅格器将生成的带纹理 3D 模型投影到 2D 的高分辨率图片。整个过程都是可微分的,使得整个对抗训练可以从 discriminator 传递到两个分支

1.3D 纹理网格生成器

生成目标是 mesh M 和 texture E

输入是采样的高斯分布 z1 和 z2,通过非线性映射网络得到 w1 和 w2

w1 用来控制 3D 模型的形状
w2 用来控制 3D 模型的纹理

非线性网络是 8 层的 MLP 网络,每层是 512 维和 leaky-ReLU 的激活函数。

整个结构如下图所示:
在这里插入图片描述

1.1几何生成器

GET3D 的几何生成器包含最近提出的可微分表面表征 DMTet。DMTet 将 3D 模型表示成一个可变形三角面片的四面体(tetrahedron)上面的符号距离场(signed distance field,SDF),从四面体可以可微分的恢复 3D 模型mesh。通过移动顶点来使得 mesh 变形,从而更好的利用它的分辨率。通过使用 DMTet,就可以生成任意拓扑结构和类别的 3D 模型。

四面体的结构如下图所示:

在这里插入图片描述

网络结构

1.从 512 维的 w1 生成 SDF 值和顶点
2.首先使用条件 3D 卷积生成条件为 w1 的特征向量
3.然后使用三线性插值在每个顶点上查询该特征,输入 MLP,输出 SDF 的值 si 和偏移量(deformation) Δvi

可微分的 mesh 提取器

1.SDF 的值 si 和偏移量(deformation) Δvi 以后,使用可微分的匹配四面体算法提取显示的 mesh。匹配四面体是指基于 SDF 的值
决定每个四面体的表面拓扑结构
2.具体来说,一个 mesh 的表面是由顶点和表面的符号距离决定的,如果两个顶点的符号距离不相同,代表两个符号中间存在 mesh 表面的顶点 mij
3.mij 的梯度可以通过反向传播传导到 SDF 的值 si 和偏移量(deformation) Δvi

1.2.纹理生成器

GET3D 将纹理参数化为 texture field,具体方法是通过一个函数 ft 来把具体的表面点 p 在以 w2 为条件的时候,映射到颜色 c (RGB 表示)。因为 texture field 对于几何是有依赖的,所以输入也依赖于 w1 。最终的公式如下,两个条件变量通过 concat 的方式组合。

在这里插入图片描述

网络结构

1.texture field 表示为三平面表征,对于 3D 物体重建和 3D-aware 的图片生成都是很高效的
2.具体而言,使用 2D 卷积神经网络将 latent code w1w2 映射到三维正交特征平面(three axis-aligned orthogonal feature planes),形状是 N×N×C×3N = 256 是空间的分辨率, C =32 是通道数量
3.获得特征平面以后,mesh 表面某个位置 p 的可以通过公式 在这里插入图片描述
获得特征 f.t ,再通过一个全连接来映射成 RGB 颜色 c , 在这里插入图片描述代表将 mesh 表面点 p 映射到特征平面的映射函数 ;在这里插入图片描述函数代表特征的双线性插值.

与其他的 3D-aware 图像生成工作相同的是,都使用了 neural field representation,不同的是,GET3D 只需要采样 mesh 表面点上的 texture field,而不需要沿射线的密集采样。有效的降低了渲染高分辨率图像的计算复杂度,并且保证了生成的模型在多个视角都能保持一致。

2.可微分渲染和训练

GET3D 将生成的 3D mesh 和 texture field 使用可微分渲染器渲染到 2D 的图片,使用 2D 的辨别器来进行监督,辨别器需要分辨真实物体的图片和生成物体的渲染图片。

1.可微分渲染

我们假设图像所使用的相机分布是已知的,为了渲染生成的 3D shape,通过随机采样相机的分布,使用高度优化的可微分栅格器 Nvdiffrast 来把 3D mesh 渲染到一个 2D 轮廓和图片,每个像素包含了 mesh 上对应的 3D 点。再利用三维坐标到 texture field 中获取颜色。

2.鉴别器 和 训练目标

GET3D 通过对抗目标来进行训练。采用 StyleGAN 的辨别器网络,使用 R1 正则化来确认相同的非饱和GAN 训练目标。通过经验,作者们发现使用两个独立的辨别器,一个用来辨别 RGB 图片,一个用来辨别轮廓,能够产生更好的结果。
代表辨别器,x 代表 RGB 图片或者轮廓。整个对抗训练目标可以定义为如下所示:
在这里插入图片描述
在这里插入图片描述
微分的,因此梯度可以传导回 3D 的生成器中。

正则化

为了删除 mesh 内部不可见的部分,对几何生成器做相邻点 SDF 值的交叉熵,并且增加了正则化。
在这里插入图片描述
四面体的每一条边的两个顶点 vi,vj, 上做如上损失,H 代表二元交叉熵,σ 代表 sigmoid 激活函数。其中要求 两个顶点 SDF 的值得符号不相同。

整个 Loss 定义如下

在这里插入图片描述

3.生成效果

下图最后两列是 GET3D 生成得不带纹理和带纹理得 mesh,可以看到相比之前得工作,效果有了非常显著得提升,mesh 的细节更加的精细。

在这里插入图片描述

通过修改 latent code 来改变生成模型的形状和颜色。

在这里插入图片描述

五、drag3D代码(dragGAN+get3D)

0.噪声w生成三平面特征

输入:wgeo (1,7,512)和wtex(1,7,512) ,以及初始化的立方体:self.const = torch.nn.Parameter(512,4,4), 生成mesh。经过几次卷积,得到img,就是最后的两种 三平面特征:
x, img = block(x, img, cur_ws_tex, cur_ws_geo, **block_kwargs) # None None (1,1,512) (1,2,512) --> (1,128,256,256) (1,192,256,256)

重点是其中的block,利用卷积和仿射变换,将输入的 2 个w ,映射到输出 x和img

plane_feat = self.tri_plane_synthesis(ws_tex[:, :self.num_ws_tex], ws_geo[:, :self.num_ws_geo], **block_kwargs)
# 以下是展开:

    def forward(self, ws_tex, ws_geo, **block_kwargs):
        block_ws_tex = []
        block_ws_geo = []
        with torch.autograd.profiler.record_function('split_ws'):
            w_idx_tex = 0
            w_idx_geo = 0
            for res in self.block_resolutions:                   # 循环,每次提升分辨率[4, 8, 16, 32, 64, 128, 256]
                block = getattr(self, f'b{res}')
                block_ws_tex.append(ws_tex.narrow(1, w_idx_tex, block.num_torgb))    # 每次循环append,裁剪其中的一维(1,1,512) 
                block_ws_geo.append(ws_geo.narrow(1, w_idx_geo, block.num_togeo + block.num_conv))   # (1,1,512)(1,3,512)(1,3,512)...
                w_idx_geo += (block.num_conv + block.num_togeo)   # 2 5 8 11...20
                w_idx_tex += (block.num_torgb)                    # 1 2 3 4...7
        x = img = None
        for res, cur_ws_tex, cur_ws_geo in zip(self.block_resolutions, block_ws_tex, block_ws_geo):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws_tex, cur_ws_geo, **block_kwargs)
        return img

block网络的内容:
在这里插入图片描述

block中的具体操作:

---------------1.首次融合:x = self.conv1(x, next(w_geo_iter), fused_modconv=True)  # (1,512,4,4)----------
styles = self.affine(w)        # (1,512) --> (1,512)
x = modulated_conv2d( x=x, weight=self.weight, styles=styles, noise=noise, up=1,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=True,
            fused_modconv=True)
# padding=1, noise:torch.rand 生成 self.weight:nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])(512,512,3,3)
    def modulated_conv2d():
        # 相当于一个分解卷积,其中的权重和偏置,都是自己设置
        w = weight * styles.reshape(batch_size, 1, -1, 1, 1) 
        if demodulate:
            dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # (1,512,512,3,3)
        if demodulate and fused_modconv:
            w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)    # (1,512,512,3,3)
        x = conv2d_resample.conv2d_resample( x=x, w=w.to(x.dtype), f=resample_filter(4,4), up=1, down=1, padding=1, groups=1, flip_weight=True)
            conv2d_gradfix.conv2d(x, w, stride=stride, padding=padding, groups=groups)
            torch.nn.functional.conv2d(input=x, weight=weight, bias=None, stride=1, padding=[1,1], dilation=1, groups=1)
        x = x.add_(noise)     # (4,4)
x = bias_act.bias_act(x, self.bias.to(x.dtype), act='lrelu', gain=1.414, clamp=None)    # self.bias=nn.Parameter(out_channel=512)    


--------------------------2.分别生成 geo 和 tex 特征----------------------------------
 # next(w_geo_iter) 取的是输入w_geo第二维(1,512),上一步用的是第一维(1,512)
 
geo_y = self.togeo(x, next(w_geo_iter), fused_modconv=fused_modconv)        # w_geo生成geo特征 (1,96,4,4) 
tex_y = self.totex(x, next(w_tex_iter), fused_modconv=fused_modconv)


# 以上两个网络,仍然是 affine + modulated_conv2d 操作。

new_img = torch.cat([geo_y, tex_y], dim=1)                                        # (1,192,4,4)

最终的 new_img 维度为 (1,192,256,256) ,包含了 sdf 与 tex 特征。以上代码包含于以下整体代码

ws_geo = torch.cat([self.ws_geo_nonparam, self.ws_geo_param], dim=1)        # (1,22,512)
ws_tex = self.mesh.ws_tex                                                   # (1,9,512)
sdf_feature, tex_feature = self.model.G.synthesis.generator.get_feature(
            ws_tex[:, :self.model.num_ws_tex_triplane], # 1,7,512
            ws_geo[:, :self.model.num_ws_geo_triplane])  # 1,20,512
# sdf_feature:[1, 96, 256, 256] , triplane features三平面特征

1.得到 drag point pairs (no need to have grad):在3维模型中拖拽的两对点

在这里插入图片描述

with torch.no_grad():
     mask_points = torch.tensor(self.points_mask, dtype=torch.bool, device=self.device)    # [true, true]
     source_points = torch.tensor(self.points_3d, dtype=torch.float32, device=self.device)[mask_points] # [N, 3]
     # self.points_3d:两个三维点[[-0.3002126, 0.0872703, -0.13610830], [-0.3105514, 0.0829553, 0.13541322]]

     target_points = source_points + torch.tensor(self.points_3d_delta, dtype=torch.float32, device=self.device)[mask_points]
     # self.points_3d_delta:两个三维向量[[0.052246, -0.049535, -0.0102899], [0.0412190, -0.0301212, 0.017559306]]

     directions = safe_normalize(target_points - source_points)   #[ 0.7184,-0.6811, 0.1415],[0.7635,-0.5579, 0.3252]]

随后,将单个drag点,扩大到一个半径内,通过在point上,添加预设的 self.offsets1,维度(343,3):
在这里插入图片描述

     resolution = sdf_feature.shape[-1]       # sdf特征的分辨率 256
     step_size = 0.1 / resolution             # 0.003 重要! 足够小,使 point tracking 变得可能
                
     # expand source to a patch based on radius
     patched_points = source_points.unsqueeze(0) + step_size * self.offsets1.unsqueeze(1) # [B, N, 3]
     B, N = patched_points.shape[:2]

     # shift points
     shifted_points = patched_points + step_size * directions # [B, N, 3]34323

2.计算 运动监督 的损失

patched_feat = self.model.G.synthesis.generator.get_sdf_def_prediction(sdf_feature, patched_points.reshape(1, -1, 3), return_feats=True).reshape(B, N, -1) # [B, N, C] (343,2,32)
shifted_feat = self.model.G.synthesis.generator.get_sdf_def_prediction(sdf_feature, shifted_points.reshape(1, -1, 3), return_feats=True).reshape(B, N, -1) # [B, N, C]

loss = F.l1_loss(shifted_feat, patched_feat.detach())
loss.backward()

其中,get_sdf_def_prediction,用于预测 四面体顶点的SDF 和 deformation 。代码如下:

    def get_sdf_def_prediction(self, sdf_feature, position, ws_geo=None, return_feats=False):
        '''
        Predicting SDF and deformation for the vertices
        :参数 sdf_feature: geometry的三平面特征   (1,96,256,256)
        :参数 position: location for the tetrahedral grid vertices 四面体网格顶点的位置 (1,686,3)
        :参数 ws_geo: latent code for geometry       None
        :return:
        '''
        # 三平面特征,分成三份
        tri_plane = torch.split(sdf_feature, self.img_feat_dim, dim=1)
        
        # 归一化
        normalized_tex_pos = (position - self.shape_min) / self.shape_length
        normalized_tex_pos = torch.clamp(normalized_tex_pos, 0, 1) 
        normalized_tex_pos = normalized_tex_pos * 2.0 - 1.0                # (1,686,3)

        # 利用 torch.nn.functional.grid_sample 函数,在特征中插值二维的顶点
        x_feat = grid_sample_gradfix.grid_sample(
            tri_plane[0],                                                  # (1,32,256,256)
            torch.cat(
                [normalized_tex_pos[:, :, 0:1], normalized_tex_pos[:, :, 1:2]],
                dim=-1).unsqueeze(dim=1).detach())                         # (1,1,686,2)    --> (1,32,1,686)

        y_feat = grid_sample_gradfix.grid_sample(
            tri_plane[1],
            torch.cat(
                [normalized_tex_pos[:, :, 1:2], normalized_tex_pos[:, :, 2:3]],
                dim=-1).unsqueeze(dim=1).detach())                                        # (1,32,1,686)

        z_feat = grid_sample_gradfix.grid_sample(
            tri_plane[2],
            torch.cat(
                [normalized_tex_pos[:, :, 0:1], normalized_tex_pos[:, :, 2:3]],
                dim=-1).unsqueeze(dim=1).detach())

        final_feat = (x_feat + y_feat + z_feat)                   # (1,32,1,686)
        final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1)   #  (1,686,32)

        if return_feats:
            return final_feat

3.point tracking (update points_3d)

with torch.no_grad():
     source_feat = patched_feat[(B - 1) // 2] # [N, C] (2,32) 找到343个点,最中间点的特征

     # expand source to a patch based on a larger radius
     patched_points = source_points.unsqueeze(0) + step_size * self.offsets1.unsqueeze(1) # [B, N, 3] (343,2,3)
                
     # 计算更新的 sdf 特征     
     ws_geo = torch.cat([self.ws_geo_nonparam, self.ws_geo_param], dim=1)         # (1,22,512)
     ws_tex = self.mesh.ws_tex                                                    # (1,9,512)
     new_sdf_feature, new_tex_feature = self.model.G.synthesis.generator.get_feature(
            ws_tex[:, :self.model.num_ws_tex_triplane], # 7
            ws_geo[:, :self.model.num_ws_geo_triplane] # 20
            ) # [1, 96, 256, 256] x 2, triplane features

     # 渲染得到最新的 点集的 三平面特征
     new_patched_feat = self.model.G.synthesis.generator.get_sdf_def_prediction(new_sdf_feature, patched_points.reshape(1, -1, 3), return_feats=True).reshape(B, N, -1) # [B, N, C] 343,2,32

     # 找到最近的点
     dist = torch.mean((new_patched_feat - source_feat) ** 2, dim=-1) # [B, N] (343,2)
     indices = torch.argmin(dist, dim=0) # [N] [304,304]
                

     # 新点替换旧点 (更新points_3d 和 距离delta)
     new_source_points = torch.gather(patched_points, dim=0, index=indices.view(1,-1,1).repeat(1,1,3)).squeeze(1) # [N, 3]
     new_points_delta = target_points - new_source_points # [N, 3]

     new_source_points_with_deleted[np.array(self.points_mask)] = new_source_points.detach().cpu().numpy()
     new_source_points_delta_with_deleted[np.array(self.points_mask)] = new_points_delta.detach().cpu().numpy()

     self.points_3d = new_source_points_with_deleted.tolist()                # 新的半径内,跟原始点特征最近的点
     self.points_3d_delta = new_source_points_delta_with_deleted.tolist()    # 新点跟目标点的差

4.更新渲染的mesh

# 用给定的latent code 产生 mesh ,self.mesh.ws_geo_last是latent code
v, f, sdf, deformation, v_deformed, sdf_reg_loss = self.model.G.synthesis.get_geometry_prediction(self.mesh.ws_geo_last, new_sdf_feature)
 
---------# Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid.
     sdf, deformation, sdf_reg_loss = self.get_sdf_deformation_prediction(ws, sdf_feature=sdf_feature)
           init_position = self.dmtet_geometry.verts.unsqueeze(dim=0)   #(1,98653,3)
           sdf, deformation = self.generator.get_sdf_def_prediction( sdf_feature, ws_geo=ws, position=init_position)
           sdf = self.mlp_synthesis_sdf(ws_geo, final_feat)  # ws_geo(12512) final_feat(1, 98653,32)
                 styles = self.affine(w)       # (1,512)->(1,32)
                 x = modulated_fc(x=x, weight=self.weight, styles=styles, noise=noise)  # weight 和styles相乘,作为fc权重 得到sdf:(1, 98653, 32)
                 
           sdf = self.mlp_synthesis_sdf(ws_geo, final_feat)         #  module_fc --> (1, 98653, 1)
           deformation = self.mlp_synthesis_def(ws_geo, final_feat) #  module_fc --> (1, 98653, 3)

---------# Step 2: Normalize the deformation 来防止 flipped triangles.

     deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)    # 90*1*tanh
     sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)               # batch_size=1


---------# Step 3:修复一些sdf,如果观察到空的形状(全正或全负)(基本不会执行,可跳过)

        pos_shape = torch.sum((sdf.squeeze(dim=-1) > 0).int(), dim=-1)       # 4816
        neg_shape = torch.sum((sdf.squeeze(dim=-1) < 0).int(), dim=-1)       # 93837
        zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
        if torch.sum(zero_surface).item() > 0:
            update_sdf = torch.zeros_like(sdf[0:1])
            max_sdf = sdf.max()
            min_sdf = sdf.min()
            update_sdf[:, self.dmtet_geometry.center_indices] += (1.0 - min_sdf)  # greater than zero
            update_sdf[:, self.dmtet_geometry.boundary_indices] += (-1 - max_sdf)  # smaller than zero
            new_sdf = torch.zeros_like(sdf)
            for i_batch in range(zero_surface.shape[0]):
                if zero_surface[i_batch]:
                    new_sdf[i_batch:i_batch + 1] += update_sdf
            update_mask = (new_sdf == 0).float()
            # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
            sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
            sdf = sdf * update_mask + new_sdf * (1 - update_mask)

---------# Step 4: 移除 bad sdf的梯度  (full positive or full negative)(基本不会执行,可跳过)

        final_sdf = []
        final_def = []
        for i_batch in range(zero_surface.shape[0]):
            if zero_surface[i_batch]:
                final_sdf.append(sdf[i_batch: i_batch + 1].detach())
                final_def.append(deformation[i_batch: i_batch + 1].detach())
            else:
                final_sdf.append(sdf[i_batch: i_batch + 1])
                final_def.append(deformation[i_batch: i_batch + 1])
        sdf = torch.cat(final_sdf, dim=0)
        deformation = torch.cat(final_def, dim=0)
        return sdf, deformation, sdf_reg_loss

dmtet:从随机顶点,生成mesh的顶点verts 和面片faces

        # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid.
        sdf, deformation, sdf_reg_loss = self.get_sdf_deformation_prediction(ws, sdf_feature=sdf_feature)  # (1, 98653, 1)(1, 98653, 3)
        v_deformed = self.dmtet_geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation # (1, 98653, 3)
        tets = self.dmtet_geometry.indices                                                                 # (1, 98653, 4)
        n_batch = ws.shape[0]
        v_list = []
        f_list = []

        # Step 2: Using marching tet to obtain the mesh
        for i_batch in range(n_batch):
            verts, faces = self.dmtet_geometry.get_mesh(v_deformed[i_batch], sdf[i_batch].squeeze(dim=-1), with_uv=False, indices=tets)   # (15166.3)(30652,3)
            # 其中,fverts为归一化的三维坐标,faces为015166的顶点索引
            v_list.append(verts)

放射变换的函数:
在这里插入图片描述

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值