广义少样本学习之对齐VAE
文章目录
论文下载
VAE 变分自编码器
变分自编码器是一种生成模型。它包含两部分,编码器和解码器。首先,编码器在样本 x x x 上学习一个样本特定的正态分布;然后,从这个正态分布中随机采样一个变量;最后,解码器将这个变量作为输入,然后生成一个样本 x ^ \hat x x^。
模型
Cross and Distribution Aligned VAE
basic M VAE losses VAE损失
L V A E = ∑ i M E q ϕ ( z ∣ x ) [ log p 0 ( x ( i ) ∣ z ) ] − β D K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ) ) (2) \mathcal{L}_{VAE} = \sum_i^M \mathbb{E}_{q_{\phi (z|x)}} [\log{p_0(x^{(i)}|z)}] \\ -\beta D_{KL}(q_{\phi}(z|x^{(i)})||p_{\theta}(z)) \tag{2} LVAE=i∑MEqϕ(z∣x)[logp0(x(i)∣z)]−βDKL(qϕ(z∣x(i))∣∣pθ(z))(2)
Cross-Alignment (CA) Loss 跨域对齐损失
L C A = ∑ i M ∑ j ≠ i M ∣ x ( j ) − D j ( E i ( x ( i ) ) ) ∣ (3) \mathcal{L}_{CA} = \sum_i^M \sum_{j \neq i}^M |x^{(j)} - D_j(E_i(x^{(i)}))| \tag{3} LCA=i∑Mj=i∑M∣x(j)−Dj(Ei(x(i)))∣(3)
Distribution-Alignment (DA) Loss 分布对齐损失
分布i和分布j的2-Wasserstein 距离的闭形式解如下:
W i j = [ ∣ ∣ μ i − μ j ∣ ∣ 2 2 + T r ( ∑ i ) + T r ( ∑ j ) − 2 ( ∑ i 1 2 ∑ i ∑ j 1 2 ) 1 2 ] 1 2 (4) W_{ij} = [||\mu_i - \mu_j||_2^2\\ + Tr(\sum_i) + Tr(\sum_j) - 2 (\sum_i^{\frac{1}{2}} \sum_i \sum_j^{\frac{1}{2}})^{\frac{1}{2}}]^{\frac{1}{2}} \tag{4} Wij=[∣∣μi−μj∣∣22+Tr(i∑)+Tr(j∑)−2(i∑21i∑j∑21)21]21(4)
由于编码器预测对角协方差矩阵,这是交换的,这个距离可以简化:
W i j = ( ∣ ∣ μ i − μ j ∣ ∣ 2 2 + ∣ ∣ ∑ i 1 2 − ∑ j 1 2 ∣ ∣ F r o b e n i u s 2 ) 1 2 (5) W_{ij} = (||\mu_i - \mu_j||_2^2 + ||\sum_i^{\frac{1}{2}} - \sum_j^{\frac{1}{2}}||_{Frobenius}^{2})^{\frac{1}{2}} \tag{5} Wij=(∣∣μi−μj∣∣22+∣∣i∑21−j∑21∣∣Frobenius2)21(5)
所以,对于M个域DA损失如下:
L
D
A
=
∑
i
M
∑
j
≠
i
M
W
i
j
(6)
\mathcal{L}_{DA} = \sum_i^M \sum_{j \neq i}^M W_{ij} \tag{6}
LDA=i∑Mj=i∑MWij(6)
CADA-VAE loss
L C A D A − V A E = L V A E + γ L C A + δ L D A (7) \mathcal{L}_{CADA-VAE} = \mathcal{L}_{VAE} + \gamma \mathcal{L}_{CA} + \delta \mathcal{L}_{DA} \tag{7} LCADA−VAE=LVAE+γLCA+δLDA(7)