Few Shot Generative Model Adaption via Relaxed Spatial Structural Alignment
1.引言
在小样本学习领域,生成模型适应性的研究对于提升模型在有限数据条件下的学习能力至关重要。传统的生成模型通常需要大量数据来训练,而在实际应用中,尤其是在特定的或者罕见的场景中,获取大量标注数据通常是不切实际的。小样本学习旨在解决这一问题,通过开发能够从少量样本中快速适应和学习的模型,来推广模型的应用范围,提高其在特定任务上的性能。有效的生成模型适应性不仅可以增强模型对新任务的快速适应能力,还能在数据稀缺的领域中发挥关键作用,如医学图像分析、生物信息学和其他需要定制化模型的应用。
面临的挑战
使用少样本的生成模型在适应性方面面临的主要挑战包括:
-
模型过拟合与崩溃:在极少样本的情况下(少于10个),常规的生成模型通常会产生质量较差的图像,并且容易发生早期崩溃。
-
身份退化和非自然失真:即使是通过对比学习提出的保持源域实例间相对距离的方法,也无法保证每个图像的固有结构,导致样本在目标域空间中漂移,最终造成身份退化和非自然的失真或纹理变化。
-
现有方法的局限性:尽管一些方法通过微调策略直接建模目标域的分布,或通过施加强正则化、轻微扰动网络参数、数据增强等方法来避免过拟合,但这些方法通常只适用于训练图像数量超过100的情况。
针对这些挑战,论文提出了一种新的适应性方法,即放松的空间结构对齐(RSSA),该方法利用源域图像更丰富的空间结构先验来解决生成模型的身份退化问题。
2.Relaxed Spatial Structural Alignment (RSSA)
Relaxed Spatial Structural Alignment (RSSA) 的详细工作原理包括以下几个关键部分:
-
生成模型的预训练与适应:在RSSA中,首先有一个在大规模源域数据集 D s D_s Ds上预训练的生成器 G s G_s Gs。这个生成器将从潜在空间 z ∼ p ( z ) ⊂ R d z \sim p(z) \subset \mathbb{R}^d z∼p(z)⊂Rd中采样的噪声向量 z z z映射到像素空间中的生成图像 G s ( z ) G_s(z) Gs(z)。少样本适应的目标是将 G s G_s Gs从源域适应到目标域,并使用目标域中的少量样本获取目标生成器 G t G_t Gt。通常,这是通过用 G s G_s Gs 初始化 G t G_t Gt并在目标域数据集 D t D_t Dt 上微调 G t G_t Gt来完成的。
-
跨域空间结构一致性损失: L G s ↔ G t L_{Gs \leftrightarrow Gt} LGs↔Gt是RSSA方法的核心,由自相关一致性损失和干扰相关一致性损失组成。这一损失函数帮助保持源域和目标域生成的图像对之间的结构一致性,通过保持图像的内在空间结构和变化趋势来解决身份退化和图像失真问题。
-
什么是跨域空间结构一致性损失?
跨域空间结构一致性损失(Cross-domain spatial structural consistency loss)是为了解决少样本生成模型适应问题而提出的一种创新损失函数。这一损失函数包括两个主要部分:自相关一致性损失(self-correlation consistency loss)和干扰相关一致性损失(disturbance correlation consistency loss)。这两部分共同作用于模型,以保持和传递源域中图像的内在空间结构及其变化趋势到目标域。
-
自相关一致性损失:用于约束图像的固有结构。通过这一部分,模型学习在源域和目标域生成的图像对之间保持相似的自相关特征,从而确保图像在结构上的一致性。
自相关一致性损失(Self-correlation consistency loss)的计算在论文中的具体实现如下:
-
特征图的自相关矩阵:对于每个卷积层,使用该层的特征图 f l ∈ R c × w × h f_l \in \mathbb{R}^{c \times w \times h} fl∈Rc×w×h 来形成图像的内在结构信息。这里, f l ( x , y ) f_l(x, y) fl(x,y) 是一个 c c c 维向量,表示在 l l l 层中位置 ( x , y ) (x, y) (x,y) 的特征。
-
自相关矩阵的计算:对于位置 ( x , y ) (x, y) (x,y) 在 l l l 层的自相关矩阵 C x y l ∈ R w × h C_{xy}^l \in \mathbb{R}^{w \times h} Cxyl∈Rw×h,每个元素 C x y l ( i , j ) C_{xy}^l(i, j) Cxyl(i,j) 可以通过以下公式计算:
C x y l ( i , j ) = cos ( f l ( x , y ) , f l ( i , j ) ) C_{xy}^l(i, j) = \cos(f_l(x, y), f_l(i, j)) Cxyl(i,j)=cos(fl(x,y),fl(i,j))
其中, cos ( ⋅ ) \cos(\cdot) cos(⋅) 表示余弦相似度函数, ( i , j ) (i, j) (i,j) 是 f l f_l fl 中的对应位置。 -
损失函数的计算:在源域生成器 G s G_s Gs 和目标域生成器 G t G_t Gt 之间计算空间自相关一致性损失:
L s c c ( G t , G s ) = E z i ∼ p ( z ) smooth 1 ( C x y t , l , C x y s , l ) L_{scc}(G_t, G_s) = \mathbb{E}_{z_i \sim p(z)} \text{smooth}_1(C_{xy}^{t,l}, C_{xy}^{s,l}) Lscc(Gt,Gs)=Ezi∼p(z)smooth1(Cxyt,l,Cxys,l)
其中 C x y t , l C_{xy}^{t,l} Cxyt,l和 C x y s , l C_{xy}^{s,l} Cxys,l 分别指示 G t G_t Gt 和 G s G_s Gs 在 l l l 层的自相关矩阵。 -
效率优化:由于自相关矩阵的计算是 O ( ( w ⋅ h ) 2 ) O((w \cdot h)^2) O((w⋅h)2) 操作,对于高分辨率的特征图,论文首先通过平均池化聚合邻近特征向量,并将整个特征图分解为补丁来计算局部自相关矩阵。
通过这种方法,自相关一致性损失帮助确保在源域和目标域生成的图像在结构上保持一致,这对于在少样本生成模型适应过程中保持图像身份的一致性和减少结构失真至关重要。
Self-correlation consistency loss计算过程的pytorch示例:
import torch import torch.nn.functional as F def self_correlation_map(feature_map): """ 计算特征图的自相关矩阵 :param feature_map: 特征图,形状为 (C, H, W) :return: 自相关矩阵 """ C, H, W = feature_map.size() feature_map = feature_map.view(C, -1) # (C, H*W) normalization_factor = torch.norm(feature_map, p=2, dim=0, keepdim=True) feature_map = feature_map / normalization_factor # 归一化 correlation_map = torch.mm(feature_map.t(), feature_map) # (H*W, H*W) return correlation_map def self_correlation_consistency_loss(source_feature_maps, target_feature_maps): """ 计算自相关一致性损失 :param source_feature_maps: 源域生成图像的特征图列表 :param target_feature_maps: 目标域生成图像的特征图列表 :return: 损失值 """ loss = 0.0 for source_feature_map, target_feature_map in zip(source_feature_maps, target_feature_maps): source_correlation_map = self_correlation_map(source_feature_map) target_correlation_map = self_correlation_map(target_feature_map) # 使用 smooth L1 损失 loss += F.smooth_l1_loss(source_correlation_map, target_correlation_map) return loss # 示例使用 # 假设 source_feature_maps 和 target_feature_maps 是从两个生成器得到的特征图列表 # loss = self_correlation_consistency_loss(source_feature_maps, target_feature_maps)
这段代码首先定义了一个函数
self_correlation_map
来计算给定特征图的自相关矩阵。接着定义了self_correlation_consistency_loss
函数来计算源域和目标域生成图像之间的自相关一致性损失。这里使用的是平滑 L1 损失(smooth L1 loss),这有助于减少异常值对损失的影响。请注意,这段代码假设
source_feature_maps
和target_feature_maps
是包含相应层特征图的列表,并且它们的长度和维度应该匹配。在实际应用中,这些特征图需要从对应的生成模型中提取。 -
-
干扰相关一致性损失:用于约束特定干扰下图像的变化趋势。这一部分帮助模型在源域和目标域之间传递图像在小干扰下的空间变化特征,从而保持图像身份的连贯性和一致性。
在论文中,干扰相关一致性损失(Disturbance correlation consistency loss)的计算过程如下:
-
定义输入噪声邻域:对于一个输入噪声向量 z i z_i zi,定义一个半径为 r r r 的邻域 U ( z i , r ) = { z ∣ ∣ z − z i ∣ < r } U(z_i, r) = \{z \,|\, |z - z_i| < r\} U(zi,r)={z∣∣z−zi∣<r}。然后,从这个邻域中采样 N N N 个噪声向量,并形成一个包含 N + 1 N + 1 N+1 个向量的批次 z n 1 N + 1 {z_n}^{N+1}_1 zn1N+1,来代表这个邻域。
-
计算像素级空间相互相关性:对于 l l l 层的特征图 f l f_l fl,定义 D j k l D_{jk}^l Djkl 为任意两个样本 z j z_j zj 和 z k z_k zk 之间的像素级空间相互相关性。在位置 ( x , y ) (x, y) (x,y)处的 D j k l D_{jk}^l Djkl 由 f j l f_{jl} fjl 中 ( x , y ) (x, y) (x,y) 处的特征向量与 f k l f_{kl} fkl 中一个小对应区域 Q Q Q(由一个滑动窗口定义)中的特征向量之间的相似性的 softmax 计算得出:
D j k l ( x , y ) = Softmax ( { cos ( f j l ( x , y ) , f k l ( m , n ) ) } ( m , n ) ∈ Q ) D_{jk}^l(x, y) = \text{Softmax}(\{\cos(f_{jl}(x, y), f_{kl}(m, n))\}_{(m,n) \in Q}) Djkl(x,y)=Softmax({cos(fjl(x,y),fkl(m,n))}(m,n)∈Q)
其中, cos ( ⋅ ) \cos(\cdot) cos(⋅) 表示余弦相似度函数。 -
施加干扰相关一致性约束:基于计算得到的像素级相关性分布,通过最小化 L 1 L_1 L1 距离来施加干扰相关一致性约束:
L d c c ( G t , G s ) = E z i ∼ p ( z ) [ ∥ D j k t , l ( x , y ) − D j k s , l ( x , y ) ∥ 1 ] L_{dcc}(G_t, G_s) = \mathbb{E}_{z_i \sim p(z)} [\, \lVert D_{jk}^{t,l}(x, y) - D_{jk}^{s,l}(x, y) \rVert_1 \,] Ldcc(Gt,Gs)=Ezi∼p(z)[∥Djkt,l(x,y)−Djks,l(x,y)∥1]
这里 D j k l ( x , y ) D_{jk}^l(x, y) Djkl(x,y) 和 D j k l ( x , y ) D_{jk}^l(x, y) Djkl(x,y) 分别是对于目标域和源域生成器计算得到的像素级相关性矩阵。
通过这种方法,干扰相关一致性损失帮助模型在源域和目标域之间传递图像在小干扰下的空间变化特征,从而在少样本生成模型适应过程中保持图像身份的一致性和减少结构失真。
Disturbance correlation consistency loss计算过程的pytorch代码示例:
import torch import torch.nn.functional as F def pixelwise_spatial_correlation(feature_map_j, feature_map_k, delta): """ 计算像素级空间相互相关性 :param feature_map_j: 来自第一个样本的特征图 :param feature_map_k: 来自第二个样本的特征图 :param delta: 滑动窗口的宽度 :return: 相互相关性矩阵 """ C, H, W = feature_map_j.size() correlation_matrix = torch.zeros((H, W, H, W)) # 滑动窗口 for x in range(H): for y in range(W): region = feature_map_k[:, max(x - delta // 2, 0):min(x + delta // 2 + 1, H), max(y - delta // 2, 0):min(y + delta // 2 + 1, W)] region = region.view(C, -1).t() # 转置以计算余弦相似度 vector = feature_map_j[:, x, y].view(C, 1) cos_similarity = F.cosine_similarity(vector.t(), region) correlation_matrix[x, y] = F.softmax(cos_similarity.view(-1), dim=0) return correlation_matrix def disturbance_correlation_consistency_loss(source_feature_maps, target_feature_maps, delta): """ 计算干扰相关一致性损失 :param source_feature_maps: 源域生成图像的特征图列表 :param target_feature_maps: 目标域生成图像的特征图列表 :param delta: 滑动窗口的宽度 :return: 损失值 """ loss = 0.0 for source_feature_map, target_feature_map in zip(source_feature_maps, target_feature_maps): source_correlation = pixelwise_spatial_correlation(source_feature_map, source_feature_map, delta) target_correlation = pixelwise_spatial_correlation(target_feature_map, target_feature_map, delta) # 使用 L1 损失 loss += torch.mean(torch.abs(source_correlation - target_correlation)) return loss # 示例使用 # 假设 source_feature_maps 和 target_feature_maps 是从两个生成器得到的特征图列表 # loss = disturbance_correlation_consistency_loss(source_feature_maps, target_feature_maps, delta=3)
总的来说,跨域空间结构一致性损失在少样本生成模型适应中的作用是保持源域和目标域生成图像之间的结构一致性,同时传递图像的变化趋势,以防止生成图像中的身份退化和结构失真。
-
-
-
潜在空间压缩:为了放松跨域对齐并加速训练过程,引入了潜在空间压缩。这通过将潜在空间压缩到一个更接近目标域的子空间来实现,从而使得从子空间生成的合成对彼此更接近。
潜在空间压缩(Latent space compression)的工作原理是一种用于放松跨域对齐的方法,旨在将潜在空间压缩到更接近目标域的子空间。以下是它的详细工作原理:
-
目标和挑战:虽然空间结构一致性损失有助于对齐源域和目标域生成的图像,但直接对齐可能会导致源域属性的主导,从而减慢模型适应过程。
-
潜在空间的逆转:首先,使用 Image2StyleGAN 等技术,将目标域中的少量样本逆转到源域生成器 G s G_s Gs的 W + W^+ W+ 空间。这样,给定 n n n 个目标样本,可以在 l l l 层获得逆转的潜在代码集合 { w i l } i = 1 n \{w^l_i\}^n_{i=1} {wil}i=1n。
-
子空间构建:定义一个由逆转的潜在代码构成的 n n n 列矩阵 A l A^l Al,其中 A ∗ i l = w i l A^l_{*i} = w^l_i A∗il=wil,从而在 l l l 层获得一个子空间 X l X^l Xl。
-
潜在空间压缩:通过这种方式,潜在空间被压缩到一个更接近目标域的子空间,从而放松了跨域对齐。这有助于加速模型的适应过程,同时避免由于过度依赖源域属性而导致的模型适应缓慢。
-
加速训练过程:通过将生成的图像对从压缩的子空间中得到,潜在空间压缩有助于加速训练过程,使得从子空间生成的合成对在结构上更接近。
综上所述,潜在空间压缩是通过减少源域和目标域之间的直接对齐,将潜在空间引导至更接近目标域的子空间,以加快模型适应过程并避免结构失真。
Latent space compression工作原理的pytorch代码示例:
import torch def project_to_subspace(latent_code, subspace_matrix): """ 将潜在代码投影到子空间 :param latent_code: 潜在代码 :param subspace_matrix: 子空间投影矩阵 :return: 投影后的潜在代码 """ projection = subspace_matrix @ torch.inverse(subspace_matrix.T @ subspace_matrix) @ subspace_matrix.T return projection @ latent_code def latent_space_compression(latent_code, subspace_matrix, alpha): """ 潜在空间压缩 :param latent_code: 潜在代码 :param subspace_matrix: 子空间投影矩阵 :param alpha: 调制系数 :return: 调整后的潜在代码 """ projected_code = project_to_subspace(latent_code, subspace_matrix) compressed_code = alpha * projected_code + (1 - alpha) * latent_code return compressed_code # 示例使用 # 假设 latent_code 是从生成器得到的潜在代码,subspace_matrix 是子空间投影矩阵 # adjusted_latent_code = latent_space_compression(latent_code, subspace_matrix, alpha=0.5)
-
-
优化策略:
论文中提出的优化策略包括以下几个关键部分:
- 潜在空间的调制:首先,对于给定的输入噪声
z
j
z_j
zj,在
l
l
l 层对应的潜在代码
w
j
l
w_{jl}
wjl 通过调制系数
α
l
\alpha_l
αl 进行调制。这一步骤涉及到潜在空间压缩的应用,其中
w
j
l
w_{jl}
wjl 被投影到第
l
l
l 层的子平面
X
l
X_l
Xl 上。这通过最小二乘法实现,具体计算公式为:
w j l = A l ( A l T A l ) − 1 A l T w j l w_{jl} = A_l(A_l^T A_l)^{-1}A_l^T w_{jl} wjl=Al(AlTAl)−1AlTwjl - 潜在代码的调整:调整后的潜在代码
w
^
j
l
\hat{w}_{jl}
w^jl 通过以下公式计算:
w ^ j l = α l w j l + ( 1 − α l ) w j l \hat{w}_{jl} = \alpha_l w_{jl} + (1 - \alpha_l)w_{jl} w^jl=αlwjl+(1−αl)wjl
这里, α l \alpha_l αl 是调制系数,用于平衡原始潜在代码和子空间投影之间的贡献。 - 生成图像的比较:在论文中的图4中展示了使用不同潜在空间的输入潜在代码生成的图像。顶部是从原始潜在空间中采样的潜在代码,而底部是从压缩后的潜在空间中采样的潜在代码。
通过这种优化策略,论文有效地将源域生成器的潜在空间压缩到与目标域更接近的子空间,从而在保持图像质量的同时加快了模型适应过程。这种方法利用了源域和目标域之间的结构相似性,同时避免了过度依赖源域属性所带来的模型适应缓慢问题。
- 潜在空间的调制:首先,对于给定的输入噪声
z
j
z_j
zj,在
l
l
l 层对应的潜在代码
w
j
l
w_{jl}
wjl 通过调制系数
α
l
\alpha_l
αl 进行调制。这一步骤涉及到潜在空间压缩的应用,其中
w
j
l
w_{jl}
wjl 被投影到第
l
l
l 层的子平面
X
l
X_l
Xl 上。这通过最小二乘法实现,具体计算公式为:
-
结构一致性评分:为了更好地评估少样本生成模型适应性,提出了一个新的度量标准——结构一致性评分(SCS),用于衡量源域和目标域合成对的结构相似性。
Structural Consistency Score (SCS) 的计算公式是:
S C S ( G t ) = E z i ∼ p ( z ) [ 2 ∣ H ( G t ( z i ) ) ∩ H ( G s ( z i ) ) ∣ ∣ H ( G t ( z i ) ) ∣ + ∣ H ( G s ( z i ) ) ∣ ] SCS(G_t) = \mathbb{E}_{z_i \sim p(z)} \left[ \frac{2|H(G_t(z_i)) \cap H(G_s(z_i))|}{|H(G_t(z_i))| + |H(G_s(z_i))|} \right] SCS(Gt)=Ezi∼p(z)[∣H(Gt(zi))∣+∣H(Gs(zi))∣2∣H(Gt(zi))∩H(Gs(zi))∣]
这个公式计算的是生成的图像 G t ( z i ) G_t(z_i) Gt(zi)和 G s ( z i ) G_s(z_i) Gs(zi) 在结构上的一致性。其中 H H H 表示的是边缘检测函数 HED (Holistically-Nested Edge Detection),它用于提取图像的结构信息。分子中 ∣ H ( G t ( z i ) ) ∩ H ( G s ( z i ) ) ∣ |H(G_t(z_i)) \cap H(G_s(z_i))| ∣H(Gt(zi))∩H(Gs(zi))∣ 计算的是 G t G_t Gt 和 G s G_s Gs 生成的图像边缘图的交集,分母是这两个边缘图的并集。这个比例表达了两个边缘图共有边缘与总边缘的比例,从而反映了结构一致性的程度。
SCS 的值越高,说明目标域生成的图像 G t ( z i ) G_t(z_i) Gt(zi) 在结构上与源域生成的图像 G s ( z i ) G_s(z_i) Gs(zi) 保持的一致性越好,这表明在少样本适应过程中图像的结构信息得到了更好的保留。
综上所述,RSSA通过以上方法和策略实现从源域到目标域的有效适应,同时保持图像的结构一致性和身份信息。
3.总结
本文提出了一种新的少样本生成模型适应方法——放松空间结构对齐(Relaxed Spatial Structural Alignment, RSSA)。通过跨域空间结构一致性损失对源域和目标域的生成分布进行对齐,能够很好地保留和传递源域图像的固有结构信息和空间变化趋势到目标域。此外,原始潜在空间被压缩到接近目标域的狭窄子空间,这放松了跨域对齐并加速了目标域生成器的收敛速度。论文还设计了一个新的度量标准——结构一致性评分(SCS),用以评估生成图像的结构质量,可以作为当前少样本生成场景中度量标准的一个补充。
尽管该方法能够很好地处理极少量样本训练的设置,并生成引人注目的视觉结果,但它也有一些局限性。空间结构一致性损失对于一些抽象领域(例如阿梅代奥·莫迪利亚尼的画作,这些画作以面部的超现实伸长为特征)并不友好。尽管如此,作者们相信在不久的将来会提出更多数据高效的生成模型,这些模型的应用将反过来促进如少样本图像分类等一系列下游任务的发展。