Junction Tree Variational Autoencoder for Molecular Graph Generation
Year: 2018
Authors: Wengong Jin, Regina Barzilay, Tommi Jaakkola
Journal Name: ICML
Contributions
- 使用分子图自动设计分子结构
- 将整个任务分为编码(以连续方法表示分子)和解码(将连续的表示映射回分子图)
Junction Tree Variational Autoencoder
分子图和连接树提供了两个表示
z
=
[
z
T
,
z
G
]
\bm{z} = [\bm{z}_{\mathcal{T}}, \bm{z}_G]
z=[zT,zG] ,两者由编码器
q
(
z
T
∣
T
)
q(\bm{z}_{\mathcal{T}} | \mathcal{T})
q(zT∣T) 和
q
(
z
G
∣
G
)
q(\bm{z}_{G} | G)
q(zG∣G) 产生。两个解码器
p
(
T
∣
z
T
)
p(\mathcal{T} | \bm{z}_{\mathcal{T}})
p(T∣zT) 和
p
(
G
∣
T
,
z
G
)
p(G | \mathcal{T}, \bm{z}_{G})
p(G∣T,zG) 重构分子图。
Junction Tree
已知分子图 G = ( V , E ) G = (V, E) G=(V,E) ,连接树为 T G = ( V , E , X ) \mathcal{T}_G = (\mathcal{V}, \mathcal{E}, \mathcal{X}) TG=(V,E,X) ,其中 X \mathcal{X} X 为特征字典, V = { C 1 , . . . , C n } \mathcal{V} = \{ C_1, ..., C_n \} V={C1,...,Cn} 。 C i = ( V i , E i ) C_i = (V_i, E_i) Ci=(Vi,Ei) 为 G G G 的子结构,满足以下限制
- ∪ i V i = V \cup_i V_i = V ∪iVi=V , ∪ i E i = E \cup_i E_i = E ∪iEi=E
- 如果 C k C_k Ck 在从 C i C_i Ci 到 C j C_j Cj 的路径上, V i ∩ V j ⊆ V k V_i \cap V_j \subseteq V_k Vi∩Vj⊆Vk
Graph Encoder
每个节点
v
v
v 和边缘
(
u
,
v
)
∈
E
(u, v) \in E
(u,v)∈E 都有相对应的特征向量
x
v
\bm{x}_v
xv 和
x
u
v
\bm{x}_{uv}
xuv 。定义
v
u
v
\bm{v}_{uv}
vuv 为从
u
u
u 到
v
v
v 的信息
v
u
v
(
t
)
=
τ
(
W
1
g
x
u
+
W
2
g
x
u
v
+
W
3
g
∑
w
∈
N
(
u
)
∖
v
v
w
u
(
t
−
1
)
)
\bm{v}_{uv}^{(t)} = \tau(W_1^g \bm{x}_u + W_2^g \bm{x}_{uv} + W_3^g \sum_{w \in N(u) \setminus v}\bm{v}_{wu}^{(t-1)})
vuv(t)=τ(W1gxu+W2gxuv+W3gw∈N(u)∖v∑vwu(t−1))
其中,
τ
\tau
τ 为 RELU ,
v
u
v
(
t
)
\bm{v}_{uv}^{(t)}
vuv(t) 表示第
t
t
t 轮迭代后的信息,
v
u
v
(
0
)
=
0
\bm{v}_{uv}^{(0)} = 0
vuv(0)=0 。
T
T
T 轮迭代后,将信息聚合为每个节点的隐向量
h
u
=
τ
(
U
1
g
x
u
+
∑
v
∈
N
(
u
)
U
2
g
v
v
u
(
T
)
)
\bm{h}_u = \tau(U_1^g \bm{x}_u + \sum_{v \in N(u)} U_2^g \bm{v}_{vu}^{(T)})
hu=τ(U1gxu+v∈N(u)∑U2gvvu(T))
最终的图表示为 h G = ∑ i h i / ∣ V ∣ \bm{h}_G = \sum_{i} \bm{h}_i / |V| hG=∑ihi/∣V∣ 。 z G \bm{z}_G zG 从 N ( μ G , σ G ) \mathcal{N}(\bm{\mu}_G, \bm{\sigma}_G) N(μG,σG) 中采样, μ G \bm{\mu}_G μG 和 l o g σ G log \bm{\sigma}_G logσG 通过两个独立的仿射层根据 h G \bm{h}_G hG 计算得出。
Tree Encoder
对于每条边缘
(
C
i
,
C
j
)
(C_i, C_j)
(Ci,Cj) ,定义信息向量
m
i
j
\bm{m}_{ij}
mij 和
m
j
i
\bm{m}_{ji}
mji 。
m
i
j
=
G
R
U
(
x
i
,
{
m
k
i
}
k
∈
N
(
i
)
∖
j
)
\bm{m}_{ij} = GRU(\bm{x}_i, \{ \bm{m}_{ki} \}_{k \in N(i) \setminus j})
mij=GRU(xi,{mki}k∈N(i)∖j)
GRU 的结构如下所示
s
i
j
=
∑
k
∈
N
(
i
)
∖
j
m
k
i
z
i
j
=
σ
(
W
z
x
i
+
U
z
s
i
j
+
b
z
)
r
k
i
=
σ
(
W
r
x
i
+
U
r
m
i
j
+
b
r
)
m
~
i
j
=
t
a
n
h
(
W
x
i
+
U
∑
k
∈
N
(
i
)
∖
j
r
k
i
⊙
m
k
i
)
m
i
j
=
(
1
−
z
i
j
)
⊙
s
i
j
+
z
i
j
⊙
m
~
i
j
\bm{s}_{ij} = \sum_{k \in N(i) \setminus j} \bm{m}_{ki} \\ \bm{z}_{ij} = \sigma (W^z \bm{x}_i + U^z \bm{s}_{ij} + b^z) \\ \bm{r}_{ki} = \sigma(W^r \bm{x}_i + U^r \bm{m}_{ij} + b^r) \\ \widetilde{\bm{m}}_{ij} = tanh(W \bm{x}_i + U \sum_{k \in N(i) \setminus j} \bm{r}_{ki} \odot \bm{m}_{ki}) \\ \bm{m}_{ij} = (1 - \bm{z}_{ij}) \odot \bm{s}_{ij} + \bm{z}_{ij} \odot \widetilde{\bm{m}}_{ij}
sij=k∈N(i)∖j∑mkizij=σ(Wzxi+Uzsij+bz)rki=σ(Wrxi+Urmij+br)m
ij=tanh(Wxi+Uk∈N(i)∖j∑rki⊙mki)mij=(1−zij)⊙sij+zij⊙m
ij
其中,
σ
\sigma
σ 为 sigmoid 函数。信息传递之后,每个节点的隐向量
h
i
=
τ
(
W
o
x
i
+
∑
k
∈
N
(
i
)
U
o
m
k
i
)
\bm{h}_i = \tau(W^o \bm{x}_i + \sum_{k \in N(i)}U^o \bm{m}_{ki})
hi=τ(Woxi+k∈N(i)∑Uomki)
采样 z T \bm{z}_{\mathcal{T}} zT 的方法和图编码器类似。
Tree Decoder
解码过程在原分子的基础上,利用树采样继续扩展新的子结构,原分子的所有子结构均为根节点。
定义
E
~
t
\widetilde{\mathcal{E}}_t
E
t 为到
t
t
t 时刻为止已经采样的边缘,
h
i
t
j
t
\bm{h}_{i_t j_t}
hitjt 为采样过程中产生的信息。
h
i
t
j
t
=
G
R
U
(
x
i
t
,
{
h
k
i
t
}
(
k
,
i
t
)
∈
E
~
t
,
k
≠
j
t
)
\bm{h}_{i_t j_t} = GRU(\bm{x}_{i_t}, \{ \bm{h}_{k i_t} \}_{(k, i_t) \in \widetilde{\mathcal{E}}_t, k \neq j_t})
hitjt=GRU(xit,{hkit}(k,it)∈E
t,k=jt)
定义
p
t
p_t
pt 为当前叶节点是否继续扩展的概率
p
t
=
σ
(
u
d
⋅
τ
(
W
1
d
x
i
t
+
W
2
d
z
T
+
W
3
d
∑
(
k
,
i
t
)
∈
E
~
t
h
k
i
t
)
)
p_t = \sigma(u^d · \tau(W_1^d \bm{x}_{i_t} + W_2^d \bm{z}_{\mathcal{T}} + W_3^d \sum_{(k, i_t) \in \widetilde{\mathcal{E}}_t} \bm{h}_{k i_t}))
pt=σ(ud⋅τ(W1dxit+W2dzT+W3d(k,it)∈E
t∑hkit))
定义
q
j
=
s
o
f
t
m
a
x
(
U
l
τ
(
W
1
l
z
T
+
W
2
l
h
i
j
)
)
q_j = softmax(U^l \tau(W_1^l \bm{z}_{\mathcal{T}} + W_2^l \bm{h}_{ij}))
qj=softmax(Ulτ(W1lzT+W2lhij))
表示扩展节点
j
j
j 的特征
x
j
\bm{x}_j
xj 在特征字典
X
\mathcal{X}
X 中的概率。当
j
j
j 为根节点时,
h
i
j
=
0
\bm{h}_{ij} = 0
hij=0 。训练时采用 teacher forcing 最小化交叉熵损失
L
c
(
T
)
=
∑
t
L
d
(
p
t
,
p
^
t
)
+
∑
j
L
l
(
q
j
,
q
^
j
)
L_c(\mathcal{T}) = \sum_t L^d(p_t, \hat{p}_t) + \sum_j L^l(q_j, \hat{q}_j)
Lc(T)=t∑Ld(pt,p^t)+j∑Ll(qj,q^j)
Graph Decoder
因为相同的树所重构出的图并不唯一,定义
G
(
T
)
\mathcal{G}(\mathcal{T})
G(T) 为树
T
\mathcal{T}
T 所能重构的图的集合。
G
^
=
arg max
G
′
∈
G
(
T
)
f
a
(
G
′
)
\hat{G} = \argmax_{G' \in \mathcal{G}(\mathcal{T})} f^a(G')
G^=G′∈G(T)argmaxfa(G′)
其中,
f
a
f^a
fa 为评分函数。出于效率原因,作者按照树本身的解码顺序,一次扩展一个子结构进行计算。
假设根据树节点
C
j
C_j
Cj 新扩展的子结构为
C
i
C_i
Ci ,生成了子图
G
i
G_i
Gi ,子图所对应的向量表示为
h
G
i
\bm{h}_{G_i}
hGi ,评分函数为
f
a
(
G
i
)
=
h
G
i
⋅
z
G
f^a (G_i) = \bm{h}_{G_i} · \bm{z}_G
fa(Gi)=hGi⋅zG
定义
u
u
u 和
v
v
v 为
G
i
G_i
Gi 中的两个原子。如果
v
∈
C
i
v \in C_i
v∈Ci ,
α
v
=
i
\alpha_v = i
αv=i 。如果
v
∈
C
j
∖
C
i
v \in C_j \setminus C_i
v∈Cj∖Ci ,
α
v
=
j
\alpha_v = j
αv=j 。设立
α
v
\alpha_v
αv 是为了标注原子在树中的位置。仿照图编码器,定义
μ
u
v
\bm{\mu}_{uv}
μuv 为从
u
u
u 到
v
v
v 的信息
μ
u
v
(
t
)
=
τ
(
W
1
a
x
u
+
W
2
a
x
u
v
+
W
3
a
μ
~
u
v
(
t
−
1
)
)
μ
~
u
v
(
t
−
1
)
=
{
∑
w
∈
N
(
u
)
∖
v
μ
w
u
(
t
−
1
)
,
α
u
=
α
v
,
m
^
α
u
α
v
+
∑
w
∈
N
(
u
)
∖
v
μ
w
u
(
t
−
1
)
,
α
u
≠
α
v
.
\bm{\mu}_{uv}^{(t)} = \tau(W_1^a \bm{x}_u + W_2^a \bm{x}_{uv} + W_3^a \widetilde{\bm{\mu}}_{uv}^{(t-1)}) \\ \widetilde{\bm{\mu}}_{uv}^{(t-1)} = \left\{ \begin{aligned} \sum_{w \in N(u) \setminus v} \bm{\mu}_{wu}^{(t-1)} & , & \alpha_u = \alpha_v, \\ \hat{\bm{m}}_{\alpha_u \alpha_v} + \sum_{w \in N(u) \setminus v} \bm{\mu}_{wu}^{(t-1)} & , & \alpha_u \neq \alpha_v. \end{aligned} \right.
μuv(t)=τ(W1axu+W2axuv+W3aμ
uv(t−1))μ
uv(t−1)=⎩⎪⎪⎪⎨⎪⎪⎪⎧w∈N(u)∖v∑μwu(t−1)m^αuαv+w∈N(u)∖v∑μwu(t−1),,αu=αv,αu=αv.
计算
h
G
i
\bm{h}_{G_i}
hGi 的方法与图编码器相同。
学习图解码器参数以最大化在每个树节点处预测地面真实图 G 的正确子图 G i 的对数似然
该过程的损失函数为
L
g
(
G
)
=
∑
i
[
f
a
(
G
i
)
−
l
o
g
∑
G
i
′
∈
G
i
e
x
p
(
f
a
(
G
i
′
)
)
]
L_g(G) = \sum_i \Big[ f^a(G_i) - log \sum_{G_i' \in \mathcal{G}_i} exp(f^a(G_i')) \Big]
Lg(G)=i∑[fa(Gi)−logGi′∈Gi∑exp(fa(Gi′))]
其中,
i
i
i 为树的节点,
G
i
G_i
Gi 为正确子图。
以我的理解,
l
o
g
∑
G
i
′
∈
G
i
e
x
p
(
f
a
(
G
i
′
)
)
log \sum_{G_i' \in \mathcal{G}_i} exp(f^a(G_i'))
log∑Gi′∈Giexp(fa(Gi′)) 放大了较大
f
a
(
G
i
′
)
f^a(G_i')
fa(Gi′) 的影响,减少了较小
f
a
(
G
i
′
)
f^a(G_i')
fa(Gi′) 的影响。所以,该损失函数倾向于使正确子图的分数无穷大,错误子图的分数为 0 ,但这样的话
f
a
(
G
i
)
f^a (G_i)
fa(Gi) 直接使用内积计算相似度是否不太合理?
Results