谷歌的论文,基于seq2seq+VAE编码并生成手绘序列
https://arxiv.org/pdf/1704.03477.pdf
本文主要是论文的概述翻译,记录
文章目录
1.Introduction
- 生成模型的发展:GANs(Generative Adversarial Networks)、VI(Variational Inference)、AR(Autoregressive)
- 当前多用于处理图像像素数据(pixel images),而人的理解是矢量序列的,本文即提出用于矢量图像的生成模型(手绘草图)
- 文本贡献:适用于线序列的条件/非条件生成模型框架;提出的sketch-RNN可以生成矢量格式的有意义的图像;开发了一种可以使训练更鲁棒的方法;将矢量图映射到了潜在空间;最后讨论了本文可能的应用领域
2.Related Work
- 对于模仿绘画,有通过既定的文件执行绘画的机器人、和基于强化学习的方法,并非生成
- 神经网络用于生成的多是栅格图像;早期对于线的有HMM模型方法;最近有基于RNN的Mixture Densify Network及其改进的方法用于生成连续数据点、和汉字
- 最近有用Sequence-to-Sequence 模型结合VAE(变分自编码器)的用于英语语言编码到潜在空间的研究
- 最后提了一些现有的公开数据集
3.方法
3.1 数据集
- 取自QuickDraw应用,有20s以内绘制出的草图,上百类,每类有70k的训练样本,及2.5k验证,2.5k测试
- 数据序列组织为[dx,dy,p1,p2,p3],分别表示x、y方向的变化,p1、p2、p3表示继续绘制、结束子序列、结束绘图三个状态
3.2 Sketch-RNN
- 基本结构为Sequence-to-Sequence 变分自编码器
- 其中编码器用双向RNN,以草图作为输入,潜在空间向量作为输出(用常规的VAE,先出 μ \mu μ和 σ \sigma σ,再按正态分布算 z z z)。
h
→
=
e
n
c
o
d
e
→
(
S
)
,
h
←
=
e
n
c
o
d
e
←
(
S
r
e
v
e
r
s
e
)
,
h
=
[
h
→
;
h
←
]
h_\rightarrow=encode_\rightarrow(S), h_\leftarrow=encode_\leftarrow(S_{reverse}), h=[h_\rightarrow;h_\leftarrow]
h→=encode→(S),h←=encode←(Sreverse),h=[h→;h←]
μ
=
W
μ
h
+
b
μ
,
σ
^
=
W
σ
h
+
b
σ
,
σ
=
e
x
p
(
σ
^
2
)
,
z
=
μ
+
σ
⊙
N
(
0
,
1
)
\mu=W_{\mu}h+b_{\mu}, \hat{\sigma}=W_{\sigma}h+b_{\sigma}, \sigma=exp(\frac{\hat{\sigma}}{2}), z=\mu+\sigma \odot \mathcal{N}(0,1)
μ=Wμh+bμ,σ^=Wσh+bσ,σ=exp(2σ^),z=μ+σ⊙N(0,1)
-
解码器用自回归RNN,以序列后一点作为当前点的输出;由于之前是双向RNN编码,所以z先过一个tanh得到解码器的初始状态
[ h 0 ; c 0 ] = t a n h ( W z z + b z ) [h_0;c_0]=tanh(W_zz+b_z) [h0;c0]=tanh(Wzz+bz) -
S 0 S_0 S0定义为(0,0,1,0,0)
-
(dx,dy)通过M元正态分布的高斯混合模型(GMM)计算概率,(q1,q2,q3)作为类别来计算,M也是一个类别分布,是GMM的混合权重
p ( △ x , △ y ) = ∑ j = 1 M N ( △ x , △ y ∣ μ x , j , μ y , j , σ x , j , σ y , j , ρ x y , j ) , w h e r e ∑ j = 1 M Π j = 1 p(\triangle x,\triangle y)=\sum_{j=1}^{M} \mathcal{N}(\triangle x,\triangle y | \mu_{x,j},\mu_{y,j},\sigma_{x,j},\sigma_{y,j},\rho_{xy,j}),where \ \sum_{j=1}^{M}\Pi_j=1 p(△x,△y)=j=1∑MN(△x,△y∣μx,j,μy,j,σx,j,σy,j,ρxy,j),where j=1∑MΠj=1 -
因此,解码器的输出维度为5M+M+3,即6M+3维
x i = [ S i − 1 ; z ] , [ h i ; c i ] = f o r w a r d ( x i , [ h i − 1 ; c i − 1 ] ) , y i = W y h i + b y , y i ∈ R 6 M + 3 [ ( Π ^ μ x μ y σ ^ x σ ^ y ρ ^ x y ) 1 . . . ( Π ^ μ x μ y σ ^ x σ ^ y ρ ^ x y ) M ( q 1 ^ q 2 ^ q 3 ^ ) ] = y i x_i=[S_{i-1};z],[h_i;c_i]=forward(x_i,[h_{i-1};c_{i-1}]),y_i=W_yh_i+b_y,y_i\in \mathbb{R}^{6M+3} \\ \ \\ [(\hat{\Pi} \mu_x \mu_y \hat{\sigma}_x \hat{\sigma}_y \hat{\rho}_xy)_1...(\hat{\Pi} \mu_x \mu_y \hat{\sigma}_x \hat{\sigma}_y \hat{\rho}_xy)_M(\hat{q_1}\hat{q_2}\hat{q_3})]=y_i xi=[Si−1;z],[hi;ci]=forward(xi,[hi−1;ci−1]),yi=Wyhi+by,yi∈R6M+3 [(Π^μxμyσ^xσ^yρ^xy)1...(Π^μxμyσ^xσ^yρ^xy)M(q1^q2^q3^)]=yi -
为了使标准差值非负,使用exp和tanh来约束为-1~1之间
σ x = exp ( σ ^ x ) , σ y = exp ( σ ^ y ) , ρ x y = tanh ( ρ ^ x y ) \sigma_x=\exp(\hat{\sigma}_x),\sigma_y=\exp(\hat{\sigma}_y),\rho_{xy}=\tanh(\hat{\rho}_{xy}) σx=exp(σ^x),σy=exp(σ^y),ρxy=tanh(ρ^xy) -
类别分布概率计算为
q k = exp ( q ^ k ) ∑ j = 1 3 exp ( q ^ j ) , k ∈ { 1 , 2 , 3 } Π k = exp ( Π ^ k ) ∑ j = 1 M exp ( Π ^ J ) , k ∈ { 1 , . . . , M } q_k=\frac{\exp(\hat{q}_k)}{\sum_{j=1}^{3}\exp(\hat{q}_j)},k\in\{1,2,3\}\\ \Pi_k=\frac{\exp(\hat{\Pi}_k)}{\sum_{j=1}^{M}\exp(\hat{\Pi}_J)},k\in\{1,...,M\} qk=∑j=13exp(q^j)exp(q^k),k∈{1,2,3}Πk=∑j=1Mexp(Π^J)exp(Π^k),k∈{1,...,M} -
本方法存在着p1,p2,p3状态数据不平衡问题,通用的方法是样本加权,但这样并不适用于多类别数据集,本文的解决方法是设定最大长度,实际结束后的都用(0,0,0,0,1)来标记
-
在训练阶段,我们每次获取本时间步的结果。而在生成阶段,我们将本时间步的输出结果作为下一时间步的输入,直到输出的p3=1或达到最大长度为止
-
设置了一个温度参数 τ \tau τ,来增加序列的随机性, τ \tau τ取值在0~1之间,约接近0,模型结果越确定。
3.3Unconditional Generation
- 我们可以只训练模型的解码器,没有编码器、没有输入、没有潜在空间向量,设置初始隐藏状态为0,那么会得到一个纯生成的模型
3.4 Training
-
模型采用变分自编码器的方法,其损失函数由重建损失 L R L_R LR和KL散度损失 L K L L_{KL} LKL组成。
-
对于重建损失,分别为偏移量 ( △ x , △ y ) (\triangle x,\triangle y) (△x,△y)的对数损失 L s L_s Ls和画笔状态 ( p 1 , p 2 , p 3 ) (p_1,p_2,p_3) (p1,p2,p3)的对数损失 L p L_p Lp(注: N s N_s Ns为序列实际长度)
L s = − 1 N max ∑ i = 1 N s log ( ∑ j = 1 M Π j , i N ( △ x i , △ y i ∣ μ x , j , i , μ y , j , i , σ x , j , i , σ y , j , i , ρ x y , j , i ) ) L p = − 1 N max ∑ i = 1 N m a x ∑ k = 1 3 p k , i log ( q k , i ) , L R = L s + L p L_s=-\frac{1}{N_{\max}} \sum_{i=1}^{N_s} \log(\sum_{j=1}^{M} \Pi_{j,i} \mathcal{N}(\triangle x_i,\triangle y_i|\mu_{x,j,i},\mu_{y,j,i},\sigma_{x,j,i},\sigma_{y,j,i},\rho_{xy,j,i}))\\ L_p=-\frac{1}{N_{\max}} \sum_{i=1}^{N_{max}}\sum_{k=1}^{3} p_{k,i} \log(q_{k,i}),L_R=L_s+L_p Ls=−Nmax1i=1∑Nslog(j=1∑MΠj,iN(△xi,△yi∣μx,j,i,μy,j,i,σx,j,i,σy,j,i,ρxy,j,i))Lp=−Nmax1i=1∑Nmaxk=1∑3pk,ilog(qk,i),LR=Ls+Lp -
KL散度损失度量的是潜在向量 z z z和独立同分布的高斯向量之间的差异,(可以使不同草图在潜在空间中距离更近,使插值有意义)
L K L = − 1 2 N z ( 1 + σ ^ − μ 2 − exp ( σ ^ ) ) L o s s = L R + w K L L K L L_{KL}=-\frac{1}{2N_z} (1+\hat{\sigma}-\mu^2-\exp(\hat{\sigma}) )\\ Loss=L_R+w_{KL}L_{KL} LKL=−2Nz1(1+σ^−μ2−exp(σ^))Loss=LR+wKLLKL
4.Experiments
- 分别尝试了多类别和单类别,以及不同 w K L w_{KL} wKL
- 编码器用双向LSTM,解码器用HyperLSTM
4.1 Conditional Reconstruction
- 单独训练猫/猪的数据,设置不同的温度
τ
\tau
τ,
τ
\tau
τ越小重建约稳定;另外模型训练后能起到一定的修正作用;即使是输入个牙刷,重建时也会同时保留二者的特征
4.2 Latent Space Interpolation
4.3 Sketch Drawing Analogies(类比)
- 通过在潜在空间插值得到草图的变化过程,且设置更高的
w
K
L
w_{KL}
wKL能够产生更好的数据流形关系;
4.4 Predicting Different Endings of Incomplete Sketches
- 一个应用点,根据初始的几笔,来补充完整的草图
5.Applications and Future Work;6.结论;略
附录
-
数据预处理时,将偏移量 ( △ x , △ y ) (\triangle x,\triangle y) (△x,△y)缩放为方差为1的大小。不执行0均值操作(因为均值本身就很小)。
-
在计算KL散度损失时,引入退火算法,效果更好
-
模型设置方面,编码器512个神经元,解码器2048个。用M=20的混合组成;设置recurrent dropout保留90%;batch_size=100;Adam,学习率为0.0001,梯度裁剪为1;KL_{min}=0.2,R=0.99999(模拟退火的参数?)
-
点的数量不能多于300个,本文用了道格拉斯普克法算法将数据点压缩到200个以下
-
对于复杂的图像,重建效果较差,且更倾向于圆滑的效果
-
类别数不宜过多
-
其他略