@article{mroueh2017mcgan:,
title={McGan: Mean and Covariance Feature Matching GAN},
author={Mroueh, Youssef and Sercu, Tom and Goel, Vaibhava},
journal={arXiv preprint arXiv:1702.08398},
year={2017}}
Overview

Build IPMs from the mean and the covariance of learned features, yielding the corresponding mean-matching GAN and covariance-matching GAN.
Main Content
IPM:
$$
d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} |\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x)|.
$$
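As a concrete illustration (not from the paper), for a small finite critic class the supremum can be computed exactly from sample means; a minimal numpy sketch, where the critic class and the two samples are invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
xs_p = rng.normal(0.0, 1.0, size=200)   # sample from P
xs_q = rng.normal(0.5, 1.0, size=200)   # sample from Q

# a tiny symmetric critic class F = {f, -f : f in {x, x^2, sin x}}
critics = [lambda x: x, lambda x: x ** 2, np.sin]

# d_F(P, Q) = sup_f |E_P f - E_Q f|, estimated with sample means
d = max(abs(f(xs_p).mean() - f(xs_q).mean()) for f in critics)
print(d)
```

With a richer class $\mathscr{F}$ (e.g. 1-Lipschitz functions) the same expression recovers familiar distances such as Wasserstein-1.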
When $\mathscr{F}$ is a symmetric class, i.e. $f \in \mathscr{F} \Rightarrow -f \in \mathscr{F}$, this reduces to
$$
d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} \big\{\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x) \big\}.
$$
Mean Matching IPM
$$
\mathscr{F}_{v,w,p} := \{f(x)=\langle v, \Phi_w(x) \rangle \mid v\in \mathbb{R}^m, \|v\|_p \le 1, \Phi_w:\mathcal{X} \rightarrow \mathbb{R}^m, w \in \Omega\},
$$
where $\|\cdot\|_p$ denotes the $\ell_p$ norm and $\Phi_w$ is typically a neural network; clipping $w$ makes $\mathscr{F}_{v,w,p}$ a space of bounded linear functions (boundedness is what turns the $\sup$ in the derivation below into a $\max$).
The corresponding IPM then reduces to

$$
d_{\mathscr{F}_{v,w,p}}(\mathbb{P}, \mathbb{Q}) = \max_{w \in \Omega} \|\mu_w(\mathbb{P}) - \mu_w(\mathbb{Q})\|_q,
$$

where

$$
\mu_w(\mathbb{P}) = \mathbb{E}_{x \sim \mathbb{P}} [\Phi_w(x)] \in \mathbb{R}^m.
$$
The last equality holds because

$$
\|x\|_* = \max \{\langle v, x \rangle \mid \|v\| \le 1\},
$$

and the dual norm of $\|\cdot\|_p$ is $\|\cdot\|_q$ with $\frac{1}{p}+\frac{1}{q}=1$.
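This duality can be checked numerically: for $1 < p < \infty$ the maximizer is $v^* = \mathrm{sign}(x)\,|x|^{q-1} / \|x\|_q^{q-1}$. A small sketch (the vector and exponent are invented for illustration):

```python
import numpy as np

x = np.array([3.0, -4.0, 1.0])
p = 3.0
q = p / (p - 1)                  # Hölder conjugate: 1/p + 1/q = 1

# maximizer of <v, x> over the unit l_p ball
v = np.sign(x) * np.abs(x) ** (q - 1)
v /= np.linalg.norm(x, ord=q) ** (q - 1)

assert abs(np.linalg.norm(v, ord=p) - 1.0) < 1e-10   # v lies on the l_p sphere
assert abs(v @ x - np.linalg.norm(x, ord=q)) < 1e-10  # <v, x> = ||x||_q
```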
Primal
The full GAN training objective is then

$$
\tag{3} \min_{g_\theta} \max_{w \in \Omega} \max_{v, \|v\|_p \le 1} \mathscr{L}_{\mu} (v,w,\theta),
$$
where

$$
\mathscr{L}_{\mu} (v,w,\theta) = \langle v, \mathbb{E}_{x \sim \mathbb{P}_r} \Phi_w(x) - \mathbb{E}_{z \sim p(z)} \Phi_w(g_{\theta} (z)) \rangle.
$$
In practice the expectations are estimated with sample averages.
Dual

There is also a corresponding dual form:

$$
\tag{4} \min_{g_\theta} \max_{w \in \Omega} \|\mu_w(\mathbb{P}_r) - \mu_w(\mathbb{P}_{\theta})\|_q.
$$
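With an empirical feature matrix on each side, the inner objective of (4) is just a $q$-norm of the mean feature gap. A minimal numpy sketch (batch sizes and feature dimension invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
phi_real = rng.normal(size=(64, 16))   # Phi_w(x) for a real batch
phi_fake = rng.normal(size=(64, 16))   # Phi_w(g_theta(z)) for a fake batch

q = 2.0  # dual exponent; q = 2 recovers plain l2 mean matching
gap = phi_real.mean(axis=0) - phi_fake.mean(axis=0)  # mu_w(Pr) - mu_w(Ptheta)
loss = np.linalg.norm(gap, ord=q)
print(loss)
```

The critic ascends this quantity in $w$ while the generator descends it in $\theta$.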
Covariance Feature Matching IPM
$$
\mathscr{F}_{U,V,w} := \Big\{f(x) = \sum_{j=1}^k \langle u_j, \Phi_w(x) \rangle \langle v_j, \Phi_w(x)\rangle \,\Big|\, \langle u_i, u_j \rangle = \langle v_i, v_j \rangle = \delta_{ij}, \; w \in \Omega \Big\},
$$
which is equivalent to

$$
\mathscr{F}_{U,V,w} := \{f(x) = \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle \mid U^TU = I_k, V^TV = I_k, w \in \Omega \}.
$$
The resulting IPM satisfies

$$
d_{\mathscr{F}_{U,V,w}}(\mathbb{P}, \mathbb{Q}) = \max_{w \in \Omega} \|[\Sigma_w(\mathbb{P}) - \Sigma_w(\mathbb{Q})]_k\|_*, \quad \Sigma_w(\mathbb{P}) = \mathbb{E}_{x \sim \mathbb{P}}[\Phi_w(x)\Phi_w(x)^T],
$$

where $[A]_k$ denotes the rank-$k$ truncation of $A$: if $A = \sum_i \sigma_i u_i v_i^T$ with $\sigma_1 \ge \sigma_2 \ge \cdots$, then $[A]_k = \sum_{i=1}^k \sigma_i u_i v_i^T$; $\mathcal{O}_{m,k} := \{M \in \mathbb{R}^{m \times k} \mid M^TM = I_k\}$; and $\|A\|_* = \sum_i \sigma_i$ is the nuclear (trace) norm.
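The rank-$k$ truncation and its nuclear norm can be computed directly from an SVD; a small numpy check (the matrix and $k$ are invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(6, 6))
k = 2

# [A]_k: keep the top-k singular triplets
U, s, Vt = np.linalg.svd(A)
A_k = U[:, :k] @ np.diag(s[:k]) @ Vt[:k, :]

# ||[A]_k||_* = sigma_1 + ... + sigma_k
nuc = np.linalg.svd(A_k, compute_uv=False).sum()
assert abs(nuc - s[:k].sum()) < 1e-8
```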
Primal
$$
\tag{6} \min_{g_\theta} \max_{w \in \Omega} \max_{U,V \in \mathcal{O}_{m,k}} \mathscr{L}_{\sigma} (U,V,w,\theta),
$$
where

$$
\mathscr{L}_{\sigma} (U,V,w,\theta) = \mathbb{E}_{x \sim \mathbb{P}_r} \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle - \mathbb{E}_{z \sim p_z} \langle U^T \Phi_w(g_{\theta}(z)), V^T\Phi_w(g_{\theta}(z)) \rangle.
$$
Dual

The dual form is estimated via

$$
\tag{7} \min_{g_{\theta}} \max_{w \in \Omega} \| [\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})]_k\|_*.
$$
Note: since $\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})$ is symmetric, why is $U \neq V$? Because, although symmetric, it is not (semi-)definite, so $v_i = -u_i$ is also possible.
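Objective (7) can likewise be estimated from batches: form the empirical second-moment matrices and take the top-$k$ singular values of their difference. A hedged numpy sketch (shapes and distributions invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(3)
phi_real = rng.normal(size=(128, 16))          # Phi_w on a real batch
phi_fake = 0.5 * rng.normal(size=(128, 16))    # Phi_w on a fake batch
k = 4

# empirical Sigma_w(P) = E[Phi_w(x) Phi_w(x)^T]
sigma_real = phi_real.T @ phi_real / len(phi_real)
sigma_fake = phi_fake.T @ phi_fake / len(phi_fake)

# || [Sigma_real - Sigma_fake]_k ||_* = sum of the top-k singular values
s = np.linalg.svd(sigma_real - sigma_fake, compute_uv=False)
loss = s[:k].sum()
print(loss)
```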
Algorithm
Code

Untested.
import torch
import torch.nn as nn
from torch.nn.functional import relu
from collections.abc import Callable


def preset(**kwargs):
    """Decorator that fixes some keyword arguments of `func`."""
    def decorator(func):
        def wrapper(*args, **nkwargs):
            nkwargs.update(kwargs)
            return func(*args, **nkwargs)
        wrapper.__doc__ = func.__doc__
        wrapper.__name__ = func.__name__
        return wrapper
    return decorator


class Meanmatch(nn.Module):
    """Mean feature matching loss: primal form (3) or dual form (4)."""

    def __init__(self, p, dim, dual=False, prj='l2'):
        super(Meanmatch, self).__init__()
        self.norm = p
        self.dual = dual
        if dual:
            self.dualnorm = self.norm
        else:
            self.init_weights(dim)
            self.projection = self.proj(prj)

    @property
    def dualnorm(self):
        return self.__dualnorm

    @dualnorm.setter
    def dualnorm(self, norm):
        if norm == 'inf':
            norm = float('inf')
        elif not isinstance(norm, (int, float)):
            raise ValueError("Invalid norm")
        # Hölder conjugate q = p / (p - 1); p = 1 gives q = inf
        q = float('inf') if norm == 1 else 1 / (1 - 1 / norm)
        self.__dualnorm = preset(p=q, dim=0)(torch.norm)

    def init_weights(self, dim):
        self.weights = nn.Parameter(torch.rand((1, dim)),
                                    requires_grad=True)

    @staticmethod
    def _proj1(x):
        # project onto the l1 ball: bisection on the soft threshold
        s = x.abs()
        if s.sum() <= 1.:
            return x
        u, l = s.max(), 0.
        c = (u + l) / 2
        while (u - l) > 1e-4:
            if relu(s - c).sum() > 1.:
                l = c
            else:
                u = c
            c = (u + l) / 2
        return x.sign() * relu(s - c)

    @staticmethod
    def _proj2(x):
        # project onto the l2 ball (only rescale if outside)
        return x / torch.clamp(torch.norm(x), min=1.)

    @staticmethod
    def _proj3(x):
        # project onto the l-inf ball
        return torch.clamp(x, -1., 1.)

    def proj(self, prj):
        if prj == "l1":
            return self._proj1
        elif prj == "l2":
            return self._proj2
        elif prj == "linf":
            return self._proj3
        else:
            assert isinstance(prj, Callable), "Invalid prj"
            return prj

    def forward(self, real, fake):
        # real, fake: (batch, dim) features; average over the batch
        temp = (real - fake).mean(dim=0)
        if self.dual:
            return self.dualnorm(temp)
        else:
            # keep v inside the unit ball of the chosen norm
            self.weights.data = self.projection(self.weights.data)
            return self.weights @ temp


class Covmatch(nn.Module):
    """Covariance feature matching loss, primal form (6)."""

    def __init__(self, dim, k):
        super(Covmatch, self).__init__()
        self.init_weights(dim, k)

    def init_weights(self, dim, k):
        self.U = nn.Parameter(torch.rand((dim, k)), requires_grad=True)
        self.V = nn.Parameter(torch.rand((dim, k)), requires_grad=True)

    def qr(self, w):
        # re-orthogonalize; fix signs so diag(R) >= 0
        q, r = torch.linalg.qr(w)
        return q * r.diag().sign()

    def update_weights(self):
        self.U.data = self.qr(self.U.data)
        self.V.data = self.qr(self.V.data)

    def forward(self, real, fake):
        self.update_weights()
        # E <U^T phi(x), V^T phi(x)>, estimated per batch
        part1 = ((real @ self.U) * (real @ self.V)).sum(dim=1).mean()
        part2 = ((fake @ self.U) * (fake @ self.V)).sum(dim=1).mean()
        return part1 - part2