论文题目:Cross-domain Correspondence Learning for Exemplar-based Image Translation
论文主页:https://panzhang0212.github.io/CoCosNet/
论文链接:https://arxiv.org/abs/2004.05571
代码链接:https://github.com/microsoft/CoCosNet
摘要
本文提供了一种图像翻译的通用框架,它从输入的语义图像合成真实的照片图像。与常规不同的是这个框架可以再输入一个 exemplar image,以这个 exemplar image的风格来输出最后的真实照片图像。这个exemplar图像给输出图像更多的限制,也提供了更多信息。
这个框架主要由两部分组成,一是解决跨域语义对应的Cross domain correspondence Network,二是解决翻译生成图像的Translation network。传统方法的理论只能处理自然图像直接的关系,无法处理跨域图像,但本框架可以处理跨域图像的问题。
Cross domain correspondence Network:
首先建立了位于不同领域的input和exemplar image之间的对应关系,并对exemplar image进行了相应的扭曲,使其语义与input一致。具体是把两个域的图像映射到一个中间域,找到对应关系,从而扭曲exemplar image.
input图像xA属于A域,exemplar图像yB属于B域,作者通过把xA 和yB放入feature pyramid network(利用FPN方法)提取特征,转化为中间域S的xS和yS.
其中
θ
F
\theta_{\mathcal{F}}
θF是需要学习的参数
此步骤损失函数为:
L domain ℓ 1 = ∥ F A → S ( x A ) − F B → S ( x B ) ∥ 1 \mathcal{L}_{\text {domain }}^{\ell_{1}}=\left\|\mathcal{F}_{A \rightarrow S}\left(x_{A}\right)-\mathcal{F}_{B \rightarrow S}\left(x_{B}\right)\right\|_{1} Ldomain ℓ1=∥FA→S(xA)−FB→S(xB)∥1
由于XA和YB是不同域图像,但包含相同语义,他们转化到S域之后应当尽量对其,故损失函数为使两者在S域中的映射之间的差别。应使这个差异最小。
xA和yB都映射到域S之后,计算一个S域中他们俩的相关矩阵,然后通过softmax加权选择yB中最相关的像素。
M
(
u
,
v
)
=
x
^
S
(
u
)
T
y
^
S
(
v
)
∥
x
^
S
(
u
)
∥
∥
y
^
S
(
v
)
∥
\mathcal{M}(u, v)=\frac{\hat{x}_{S}(u)^{T} \hat{y}_{S}(v)}{\left\|\hat{x}_{S}(u)\right\|\left\|\hat{y}_{S}(v)\right\|}
M(u,v)=∥x^S(u)∥∥y^S(v)∥x^S(u)Ty^S(v)
r y → x ( u ) = ∑ v softmax v ( α M ( u , v ) ) ⋅ y B ( v ) r_{y \rightarrow x}(u)=\sum_{v} \operatorname{softmax}_{v}(\alpha \mathcal{M}(u, v)) \cdot y_{B}(v) ry→x(u)=∑vsoftmaxv(αM(u,v))⋅yB(v)
损失函数为: L r e g = ∥ r y → x → y − y B ∥ 1 \mathcal{L}_{r e g}=\left\|r_{y \rightarrow x \rightarrow y}-y_{B}\right\|_{1} Lreg=∥ry→x→y−yB∥1
Translation Network:
把扭曲的exemplar image合成输出图像。从一个固定的常量z开始,通过卷积逐步扭曲图像的风格信息。
α h , w i ( r y → x ) × F c , h , w i − μ h , w i σ h , w i + β h , w i ( r y → x ) \alpha_{h, w}^{i}\left(r_{y \rightarrow x}\right) \times \frac{F_{c, h, w}^{i}-\mu_{h, w}^{i}}{\sigma_{h, w}^{i}}+\beta_{h, w}^{i}\left(r_{y \rightarrow x}\right) αh,wi(ry→x)×σh,wiFc,h,wi−μh,wi+βh,wi(ry→x)
α i , β i = T i ( r y → x ; θ T ) \alpha^{i}, \beta^{i}=\mathcal{T}_{i}\left(r_{y \rightarrow x} ; \theta_{\mathcal{T}}\right) αi,βi=Ti(ry→x;θT)
最终生成图像:
x
^
B
=
G
(
z
,
T
i
(
r
y
→
x
;
θ
T
)
;
θ
G
)
\hat{x}_{B}=\mathcal{G}\left(z, \mathcal{T}_{i}\left(r_{y \rightarrow x} ; \theta_{\mathcal{T}}\right) ; \theta_{\mathcal{G}}\right)
x^B=G(z,Ti(ry→x;θT);θG)
最终网络为七层,得到输出图片。
另外的一些损失函数:
第一个是伪参考图像对损失,xB作为真实值,xB’是xB的变形,保持图片内容不变,如翻转等。如果吧xB’作为exemplar image,xA作为input,那么生成图像应接近xB。故损失函数为:
L feat = ∑ l λ l ∥ ϕ l ( G ( x A , x B ′ ) ) − ϕ l ( x B ) ∥ 1 \mathcal{L}_{\text {feat }}=\sum_{l} \lambda_{l}\left\|\phi_{l}\left(\mathcal{G}\left(x_{A}, x_{B}^{\prime}\right)\right)-\phi_{l}\left(x_{B}\right)\right\|_{1} Lfeat =∑lλl∥ϕl(G(xA,xB′))−ϕl(xB)∥1
第二个是参考图像转换损失,其中包含两项,perceptual loss和contextual loss。
perceptual loss:
L perc = ∥ ϕ l ( x ^ B ) − ϕ l ( x B ) ∥ 1 \mathcal{L}_{\text {perc }}=\left\|\phi_{l}\left(\hat{x}_{B}\right)-\phi_{l}\left(x_{B}\right)\right\|_{1} Lperc =∥ϕl(x^B)−ϕl(xB)∥1
contextual loss:
L context = ∑ l ω l [ − log ( 1 n l ∑ i max j A l ( ϕ i l ( x ^ B ) , ϕ j l ( y B ) ) ) ] \mathcal{L}_{\text {context }}=\sum_{l} \omega_{l}\left[-\log \left(\frac{1}{n_{l}} \sum_{i} \max _{j} A^{l}\left(\phi_{i}^{l}\left(\hat{x}_{B}\right), \phi_{j}^{l}\left(y_{B}\right)\right)\right)\right] Lcontext =∑lωl[−log(nl1∑imaxjAl(ϕil(x^B),ϕjl(yB)))]
最后是Adversarial loss:
L a d v D = − E [ h ( D ( y B ) ) ] − E [ h ( D ( G ( x A , y B ) ) ) ] \mathcal{L}_{a d v}^{\mathcal{D}}=-\mathbb{E}\left[h\left(\mathcal{D}\left(y_{B}\right)\right)\right]-\mathbb{E}\left[h\left(\mathcal{D}\left(\mathcal{G}\left(x_{A}, y_{B}\right)\right)\right)\right] LadvD=−E[h(D(yB))]−E[h(D(G(xA,yB)))]
L a d v G = − E [ D ( G ( x A , y B ) ) ] \mathcal{L}_{a d v}^{\mathcal{G}}=-\mathbb{E}\left[\mathcal{D}\left(\mathcal{G}\left(x_{A}, y_{B}\right)\right)\right] LadvG=−E[D(G(xA,yB))]
最终损失函数为:
L θ = min F , T , G max D ψ 1 L feat + ψ 2 L perc + ψ 3 L context + ψ 4 L a d v G + ψ 5 L domain ℓ 1 + ψ 6 L reg \begin{aligned} \mathcal{L}_{\theta}=\min _{\mathcal{F}, \mathcal{T}, \mathcal{G}} & \max _{\mathcal{D}} \psi_{1} \mathcal{L}_{\text {feat }}+\psi_{2} \mathcal{L}_{\text {perc }}+\psi_{3} \mathcal{L}_{\text {context }} \\ &+\psi_{4} \mathcal{L}_{a d v}^{\mathcal{G}}+\psi_{5} \mathcal{L}_{\text {domain }}^{\ell_{1}}+\psi_{6} \mathcal{L}_{\text {reg }}\end{aligned} Lθ=F,T,GminDmaxψ1Lfeat +ψ2Lperc +ψ3Lcontext +ψ4LadvG+ψ5Ldomain ℓ1+ψ6Lreg
实验
生成图像对比:
跨领域的相关度
利用correlation matrix可以计算输入语义图像和输入参考风格图像之间不同点的对应关系
图像编辑
给定一张图像及其对应的mask,对语义mask进行修改,再将原图像作为参考风格图像
方法限制
示例图像中的两辆不同颜色汽车同时与input中的汽车相对应,方法可能会产生混合颜色伪影,与现实不符;此外,在多对一映射(第二行)的 情况下,多个实例(图中的枕头)可能使用相同的样式
另外,相关矩阵等计算非常占用GPU内存,使得这个方法很难用在高分辨率的图像上。