McGan: Mean and Covariance Feature Matching GAN

Mroueh Y, Sercu T, Goel V. McGan: Mean and Covariance Feature Matching GAN[J]. arXiv: Learning, 2017.

@article{mroueh2017mcgan,
  title={McGan: Mean and Covariance Feature Matching GAN},
  author={Mroueh, Youssef and Sercu, Tom and Goel, Vaibhava},
  journal={arXiv: Learning},
  year={2017}}

The paper builds IPMs from feature means and covariances, yielding the corresponding mean matching GAN and covariance matching GAN.

Main content

IPM:
$$
d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} |\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x)|.
$$
Since $\mathscr{F}$ is a symmetric function class, i.e. $f\in \mathscr{F} \Rightarrow -f \in \mathscr{F}$, the absolute value can be dropped:
$$
d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} \big\{\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x) \big\}.
$$
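As a toy illustration (not from the paper), for a finite symmetric class of critics the IPM is simply a maximum over candidate functions of the difference of sample means:

import torch

# Toy illustration of the IPM definition with a tiny symmetric function class.
torch.manual_seed(0)
x_p = torch.randn(1000)          # samples from P (standard normal)
x_q = torch.randn(1000) + 1.0    # samples from Q (shifted normal)

# a symmetric candidate class {f, -f, ...}
candidates = [lambda x: x, lambda x: -x,
              lambda x: torch.tanh(x), lambda x: -torch.tanh(x)]

ipm = max(f(x_p).mean() - f(x_q).mean() for f in candidates)
print(float(ipm))  # the linear critic f(x) = -x dominates here (~1.0)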

Mean Matching IPM

$$
\mathscr{F}_{v,w,p}:= \{f(x)=\langle v, \Phi_w(x) \rangle \mid v\in \mathbb{R}^m, \|v\|_p \le 1, \ \Phi_w:\mathcal{X} \rightarrow \mathbb{R}^m, \ w \in \Omega\},
$$
where $\|\cdot\|_p$ denotes the $\ell_p$ norm and $\Phi_w$ is typically parameterized by a neural network. Clipping $w$ keeps $\mathscr{F}_{v,w,p}$ a space of bounded linear functionals (boundedness is what turns the $\sup$ below into a $\max$).
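A minimal sketch of such a feature map in PyTorch, with weight clipping to keep $w$ in a bounded set; the architecture and clipping bound here are illustrative, not the paper's:

import torch.nn as nn

# Illustrative feature map Phi_w with weight clipping (bounds w to a box Omega).
class Phi(nn.Module):
    def __init__(self, in_dim=784, feat_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, feat_dim))

    def forward(self, x):
        return self.net(x)

    def clip_(self, c=0.01):
        # keep every parameter of w inside [-c, c]
        for p in self.parameters():
            p.data.clamp_(-c, c)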

For this function class the IPM reduces to
$$
d_{\mathscr{F}_{v,w,p}}(\mathbb{P}, \mathbb{Q}) = \max_{w \in \Omega} \max_{v, \|v\|_p \le 1} \langle v, \mu_w(\mathbb{P}) - \mu_w(\mathbb{Q}) \rangle = \max_{w \in \Omega} \|\mu_w(\mathbb{P}) - \mu_w(\mathbb{Q})\|_q,
$$
where
$$
\mu_w(\mathbb{P})= \mathbb{E}_{x \sim \mathbb{P}} [\Phi_w(x)] \in \mathbb{R}^m.
$$

The last equality follows from the dual-norm identity
$$
\|x\|_* = \max \{\langle v, x \rangle \mid \|v\| \le 1\},
$$
together with the fact that the dual norm of $\|\cdot\|_p$ is $\|\cdot\|_q$ with $\frac{1}{p}+\frac{1}{q}=1$.
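A quick numerical check of the dual-norm identity (illustration only), using the analytic maximizer $v^* = \mathrm{sign}(x)\,|x|^{q-1}/\|x\|_q^{q-1}$:

import torch

# Check that max_{||v||_p <= 1} <v, x> = ||x||_q with 1/p + 1/q = 1.
p, q = 3.0, 1.5          # 1/3 + 1/1.5 = 1
x = torch.randn(10)

v_star = x.sign() * x.abs() ** (q - 1) / x.norm(p=q) ** (q - 1)
print(v_star.norm(p=p))                    # ~1: v* lies on the l_p unit sphere
print(torch.dot(v_star, x), x.norm(p=q))   # the two values coincide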

primal

The full GAN training problem is then
$$
\min_{g_\theta} \max_{w \in \Omega} \max_{v, \|v\|_p \le 1} \mathscr{L}_{\mu} (v,w,\theta), \tag{3}
$$
where
$$
\mathscr{L}_{\mu} (v,w,\theta) = \langle v, \ \mathbb{E}_{x \sim \mathbb{P}_r} \Phi_w(x) - \mathbb{E}_{z \sim p(z)} \Phi_w(g_{\theta} (z)) \rangle.
$$

It is estimated with minibatch averages:
$$
\hat{\mathscr{L}}_{\mu} (v,w,\theta) = \Big\langle v, \ \frac{1}{N}\sum_{i=1}^{N} \Phi_w(x_i) - \frac{1}{M}\sum_{j=1}^{M} \Phi_w(g_{\theta}(z_j)) \Big\rangle.
$$
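In code this estimate is one line per term; the helper below is a sketch with illustrative names (phi stands for the feature network $\Phi_w$ and v for an already-projected weight vector):

# Hypothetical helper: empirical mean matching loss for one minibatch.
def mean_matching_loss(v, phi, x_real, x_fake):
    mu_real = phi(x_real).mean(dim=0)   # (1/N) sum_i Phi_w(x_i)
    mu_fake = phi(x_fake).mean(dim=0)   # (1/M) sum_j Phi_w(g_theta(z_j))
    return v @ (mu_real - mu_fake)      # <v, mu_w(P_r) - mu_w(P_theta)>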

dual

There is also a corresponding dual form:
$$
\min_{g_\theta} \max_{w \in \Omega} \|\mu_w(\mathbb{P}_r) - \mu_w (\mathbb{P}_{\theta})\|_q. \tag{4}
$$

Its empirical version again replaces $\mu_w$ with minibatch averages of $\Phi_w$.
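A corresponding sketch for the dual objective (same illustrative names as above):

# Hypothetical helper: empirical dual mean matching objective, q = p / (p - 1).
def mean_matching_dual(phi, x_real, x_fake, q=2.0):
    diff = phi(x_real).mean(dim=0) - phi(x_fake).mean(dim=0)
    return diff.norm(p=q)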

Covariance Feature Matching IPM

$$
\mathscr{F}_{U,V,w} := \Big\{f(x)= \sum_{j=1}^k \langle u_j, \Phi_w(x) \rangle \langle v_j, \Phi_w(x)\rangle \ \Big| \ \langle u_i, u_j \rangle = \langle v_i, v_j \rangle = \delta_{ij}, \ w \in \Omega \Big\},
$$
which is equivalent to
$$
\mathscr{F}_{U,V,w} := \{f(x)= \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle \mid U^TU=I_k, \ V^TV=I_k, \ w \in \Omega \}.
$$
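The equivalence of the two forms is just the identity $\langle U^T\phi, V^T\phi\rangle = \sum_{j=1}^k \langle u_j,\phi\rangle\langle v_j,\phi\rangle$, which the following snippet checks numerically (illustration only):

import torch

# Check that <U^T phi, V^T phi> equals sum_j <u_j, phi><v_j, phi>.
m, k = 8, 3
phi = torch.randn(m)
U, V = torch.randn(m, k), torch.randn(m, k)

lhs = torch.dot(U.t() @ phi, V.t() @ phi)
rhs = sum(torch.dot(U[:, j], phi) * torch.dot(V[:, j], phi) for j in range(k))
print(lhs, rhs)  # identical up to floating point error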

This yields
$$
d_{\mathscr{F}_{U,V,w}}(\mathbb{P}, \mathbb{Q}) = \max_{w \in \Omega} \max_{U,V \in \mathcal{O}_{m,k}} \mathrm{Tr}\big(U^T [\Sigma_w(\mathbb{P}) - \Sigma_w(\mathbb{Q})] V \big) = \max_{w \in \Omega} \|[\Sigma_w(\mathbb{P}) - \Sigma_w(\mathbb{Q})]_k\|_*,
$$
where $\Sigma_w(\mathbb{P}) = \mathbb{E}_{x \sim \mathbb{P}} [\Phi_w(x)\Phi_w(x)^T]$ is the (uncentered) feature covariance.

Here $[A]_k$ denotes the rank-$k$ approximation of $A$: if $A = \sum_i \sigma_i u_i v_i^T$ with $\sigma_1 \ge \sigma_2 \ge \ldots$, then $[A]_k = \sum_{i=1}^k \sigma_i u_i v_i^T$. Furthermore, $\mathcal{O}_{m,k} := \{M \in \mathbb{R}^{m \times k} \mid M^TM = I_k \}$ is the set of matrices with orthonormal columns, and $\|A\|_* = \sum_i \sigma_i$ denotes the nuclear (trace) norm.
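These definitions can be checked numerically via the SVD (illustration only):

import torch

# Rank-k truncation [A]_k and its nuclear norm via the SVD.
A = torch.randn(6, 6)
U, S, Vh = torch.linalg.svd(A)                   # singular values S are sorted descending

k = 3
A_k = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]   # [A]_k
print(S[:k].sum())                               # ||[A]_k||_* = sum of top-k singular values
print(torch.linalg.matrix_norm(A_k, ord='nuc'))  # same value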

primal

$$
\min_{g_\theta} \max_{w \in \Omega} \max_{U,V \in \mathcal{O}_{m,k}} \mathscr{L}_{\sigma} (U, V, w, \theta), \tag{6}
$$
where
$$
\mathscr{L}_{\sigma} (U,V,w,\theta) = \mathbb{E}_{x \sim \mathbb{P}_r} \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle - \mathbb{E}_{z \sim p_z} \langle U^T \Phi_w(g_{\theta}(z)), V^T\Phi_w(g_{\theta}(z)) \rangle.
$$

It is estimated with minibatch averages:
$$
\hat{\mathscr{L}}_{\sigma} (U,V,w,\theta) = \frac{1}{N}\sum_{i=1}^{N} \langle U^T \Phi_w(x_i), V^T\Phi_w(x_i) \rangle - \frac{1}{M}\sum_{j=1}^{M} \langle U^T \Phi_w(g_{\theta}(z_j)), V^T\Phi_w(g_{\theta}(z_j)) \rangle.
$$

dual

$$
\min_{g_{\theta}} \max_{w \in \Omega} \| [\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})]_k\|_*. \tag{7}
$$
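A sketch of the empirical dual objective for one minibatch, assuming phi returns (batch, m) features (names are illustrative):

import torch

# Hypothetical helper: truncated nuclear norm of the empirical covariance difference.
def covariance_matching_dual(phi, x_real, x_fake, k):
    f_r, f_f = phi(x_real), phi(x_fake)            # (N, m) and (M, m) features
    sigma_r = f_r.t() @ f_r / f_r.shape[0]         # empirical Sigma_w(P_r)
    sigma_f = f_f.t() @ f_f / f_f.shape[0]         # empirical Sigma_w(P_theta)
    s = torch.linalg.svdvals(sigma_r - sigma_f)    # singular values, descending
    return s[:k].sum()                             # ||[Sigma_r - Sigma_f]_k||_*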

Note: since $\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})$ is symmetric, why allow $U \ne V$? Because symmetry does not imply positive semi-definiteness: on a negative eigendirection the singular vectors satisfy $v_i = -u_i$, so the left and right factors cannot be forced to coincide.
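A quick demonstration of this point (not from the paper): for a symmetric but indefinite matrix, the left and right singular vectors on the negative eigendirections differ by a sign.

import torch

# Symmetric but indefinite: the SVD gives v_i = -u_i on negative eigendirections.
A = torch.diag(torch.tensor([2.0, -1.0]))    # symmetric, eigenvalues 2 and -1
U, S, Vh = torch.linalg.svd(A)
print(S)        # singular values: [2., 1.]
print(U)        # columns are e1, e2 up to sign
print(Vh.t())   # the second column is the negative of U's second column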

Algorithm

[Figures: Algorithm 1 (mean matching GAN) and Algorithm 2 (covariance matching GAN) from the paper, plus generated samples. In both algorithms the discriminator step ascends the empirical loss in $(v, w)$ or $(U, V, w)$ and then maps back onto the constraint set (project $v$ onto the $\ell_p$ ball, re-orthogonalize $U, V$ via QR, clip $w$), while the generator step descends in $\theta$.]
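Below is a minimal sketch of one mean matching training iteration in the spirit of Algorithm 1; the networks, optimizers, hyper-parameters, and the choice of $\ell_2$ projection are placeholders, not the paper's settings:

import torch

# Sketch of one mean matching GAN iteration; all names and settings are illustrative.
# opt_d optimizes list(phi.parameters()) + [v]; opt_g optimizes generator.parameters().
def train_step(phi, v, generator, x_real, opt_d, opt_g,
               n_critic=5, clip=0.01, z_dim=64):
    # discriminator (phi, v): maximize <v, mean Phi(real) - mean Phi(fake)>
    for _ in range(n_critic):
        z = torch.randn(x_real.shape[0], z_dim)
        x_fake = generator(z).detach()
        loss_d = -(v @ (phi(x_real).mean(0) - phi(x_fake).mean(0)))
        opt_d.zero_grad()
        loss_d.backward()
        opt_d.step()
        with torch.no_grad():
            v /= v.norm().clamp(min=1.0)     # project v back onto the l2 ball
            for p in phi.parameters():
                p.clamp_(-clip, clip)        # keep w inside a bounded set Omega
    # generator: minimize the same objective in theta
    z = torch.randn(x_real.shape[0], z_dim)
    loss_g = v.detach() @ (phi(x_real).mean(0) - phi(generator(z)).mean(0))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()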

Code

The code below is untested.



import torch
import torch.nn as nn
from torch.nn.functional import relu
from collections.abc import Callable



def preset(**kwargs):
    # decorator factory: pre-binds the given keyword arguments onto func before calling it
    def decorator(func):
        def wrapper(*args, **nkwargs):
            nkwargs.update(kwargs)
            return func(*args, **nkwargs)
        wrapper.__doc__ = func.__doc__
        wrapper.__name__ = func.__name__
        return wrapper
    return decorator


class Meanmatch(nn.Module):

    def __init__(self, p, dim, dual=False, prj='l2'):
        # p:    exponent of the constraint ||v||_p <= 1 (or 'inf')
        # dim:  feature dimension m of Phi_w
        # dual: if True, compute the dual form (4); otherwise the primal form (3)
        # prj:  projection keeping v inside the ball ('l1', 'l2', 'linf' or a callable),
        #       chosen to match p
        super(Meanmatch, self).__init__()
        self.norm = p
        self.dual = dual
        if dual:
            self.dualnorm = self.norm
        else:
            self.init_weights(dim)
            self.projection = self.proj(prj)


    @property
    def dualnorm(self):
        return self.__dualnorm

    @dualnorm.setter
    def dualnorm(self, norm):
        if norm == 'inf':
            norm = float('inf')
        elif not isinstance(norm, (int, float)):
            raise ValueError("Invalid norm")

        # dual exponent q with 1/p + 1/q = 1 (p = 1 gives q = inf)
        q = float('inf') if norm == 1 else 1 / (1 - 1 / norm)
        self.__dualnorm = preset(p=q)(torch.norm)


    def init_weights(self, dim):
        self.weights = nn.Parameter(torch.rand((1, dim)),
                                    requires_grad=True)

    @staticmethod
    def _proj1(x):
        # projection onto the l1 ball: bisect for the threshold c with
        # sum(relu(|x| - c)) = 1, then soft-threshold while keeping signs
        if x.abs().sum() <= 1.:
            return x
        u = x.abs().max()
        l = 0.
        c = (u + l) / 2
        while (u - l) > 1e-4:
            r = relu(x.abs() - c).sum()
            if r > 1.:
                l = c
            else:
                u = c
            c = (u + l) / 2
        return x.sign() * relu(x.abs() - c)

    @staticmethod
    def _proj2(x):
        # projection onto the l2 ball: rescale only if the norm exceeds 1
        return x / torch.norm(x).clamp(min=1.)

    @staticmethod
    def _proj3(x):
        # projection onto the l_inf ball: clamp coordinates to [-1, 1]
        return torch.clamp(x, -1., 1.)

    def proj(self, prj):
        if prj == "l1":
            return self._proj1
        elif prj == "l2":
            return self._proj2
        elif prj == "linf":
            return self._proj3
        else:
            assert isinstance(prj, Callable), "Invalid prj"
            return prj



    def forward(self, real, fake):
        # real, fake: (batch, dim) feature matrices Phi_w(x) and Phi_w(g_theta(z))
        temp = real.mean(dim=0) - fake.mean(dim=0)
        if self.dual:
            # dual form (4): ||mu_w(P_r) - mu_w(P_theta)||_q
            return self.dualnorm(temp)
        else:
            # primal form (3): project v back onto the chosen ball,
            # then return <v, mu_w(P_r) - mu_w(P_theta)>
            self.weights.data = self.projection(self.weights.data)
            return self.weights @ temp



class Covmatch(nn.Module):

    def __init__(self, dim, k):
        super(Covmatch, self).__init__()
        self.init_weights(dim, k)

    def init_weights(self, dim, k):
        temp1 = torch.rand((dim, k))
        temp2 = torch.rand((dim, k))
        self.U = nn.Parameter(temp1, requires_grad=True)
        self.V = nn.Parameter(temp2, requires_grad=True)

    def qr(self, w):
        # re-orthogonalize the columns of w (QR retraction back onto O_{m,k})
        q, r = torch.linalg.qr(w)
        sign = r.diag().sign()
        return q * sign

    def update_weights(self):
        self.U.data = self.qr(self.U.data)
        self.V.data = self.qr(self.V.data)

    def forward(self, real, fake):
        # real, fake: (batch, dim) feature matrices
        self.update_weights()
        temp1 = real @ self.U
        temp2 = real @ self.V
        temp3 = fake @ self.U
        temp4 = fake @ self.V
        # empirical L_sigma: mean over the batch of <U^T Phi_w(x), V^T Phi_w(x)>
        part1 = (temp1 * temp2).sum(dim=1).mean()
        part2 = (temp3 * temp4).sum(dim=1).mean()
        return part1 - part2
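
A quick usage sketch of the two modules above with random stand-in features (untested; shapes assumed to be (batch, dim)):

# Usage sketch with random stand-in features; in practice these would be Phi_w outputs.
dim, batch, k = 128, 32, 16
phi_real = torch.randn(batch, dim)   # would be Phi_w(x)
phi_fake = torch.randn(batch, dim)   # would be Phi_w(g_theta(z))

mean_primal = Meanmatch(p=2.0, dim=dim)            # primal form (3)
mean_dual = Meanmatch(p=2.0, dim=dim, dual=True)   # dual form (4)
cov = Covmatch(dim=dim, k=k)                       # primal form (6)

print(mean_primal(phi_real, phi_fake))
print(mean_dual(phi_real, phi_fake))
print(cov(phi_real, phi_fake))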

