项目实训八

本文深入剖析了树形解码器的节点分类和分支预测模块。节点分类模块由两个GRU、注意力机制和分类器构成,用于预测节点类别。分支预测模块则关注节点间的空间关系。解码过程通过GRU和注意力机制获取上下文信息,最终通过最大激活函数和softmax计算节点预测概率。文章还介绍了损失函数和未来编码实现计划。
摘要由CSDN通过智能技术生成

树形解码器的decoder部分

在解码器的每个解码步骤中,树解码器需要预测当前子节点的信息,包括子节点和子节点的分支,节点的分支表示节点与子节点之间的空间关系。我们可以通过节点类别和分支来逐步构建一棵数学树。如下图所示,为了解耦分类和空间关系预测,我们在解码器中设计了两个模块:节点分类模块和分支预测模块。
在这里插入图片描述

节点类模块

节点类模块主要包括两个GRU,一个注意力机制和一个分类器,我们首先使用两个嵌入层去获得父节点 p t p_{t} pt高维的特征向量 e t p \mathbf{e}_{t}^{\mathrm{p}} etp e t r \mathbf{e}_{t}^{\mathrm{r}} etr以及其空间关系 r t r_{t} rt。节点解码器 s t − 1 n o d e \mathbf{s}_{t-1}^{\mathrm{node}} st1node的先前隐藏状态被视为 G R U 1 n o d e \mathbf{GRU}_{1}^{\mathrm{node}} GRU1node层的先前隐藏状态。父节点的嵌入层 e t p \mathbf{e}_{t}^{\mathrm{p}} etp和空间关系节点的嵌入层 e t r \mathbf{e}_{t}^{\mathrm{r}} etr一起作为 G R U 1 n o d e \mathbf{GRU}_{1}^{\mathrm{node}} GRU1node的输入,然后就可以得到 G R U 1 n o d e \mathbf{GRU}_{1}^{\mathrm{node}} GRU1node的当前隐藏状态 S ~ t node  \widetilde{\mathbf{S}}_{t}^{\text {node }} S tnode 
e t p = Emd ⁡ node  ( p t ) \mathbf{e}_{t}^{\mathrm{p}}=\operatorname{Emd}_{\text {node }}\left(p_{t}\right) etp=Emdnode (pt) e t r = E m d r e ( r t ) \mathbf{e}_{t}^{\mathrm{r}}=\mathrm{Emd}_{\mathrm{re}}\left(r_{t}\right) etr=Emdre(rt) s ~ t node  = GRU ⁡ 1 node  ( [ e t p , e t r ] , s t − 1 node  ) \widetilde{\mathbf{s}}_{t}^{\text {node }}=\operatorname{GRU}_{1}^{\text {node }}\left(\left[\mathbf{e}_{t}^{\mathrm{p}}, \mathbf{e}_{t}^{\mathrm{r}}\right], \mathbf{s}_{t-1}^{\text {node }}\right) s tnode =GRU1node ([etp,etr],st1node )
然后,节点注意力机制模块 f att  node  f_{\text {att }}^{\text {node }} fatt node 被用来在特征映射A上的注意可能性 α t node  {\alpha}_{t}^{\text {node }} αtnode ,通过计算在A上的权重之和来获得节点上下文向量 c t node  \mathbf{c}_{t}^{\text {node }} ctnode ,这里使用 S ~ t node  \widetilde{\mathbf{S}}_{t}^{\text {node }} S tnode 作为query并且A作为key和value。
α t node  = f a t t node  ( A , s ~ t node  ) \boldsymbol{\alpha}_{t}^{\text {node }}=f_{\mathrm{att}}^{\text {node }}\left(\mathbf{A}, \widetilde{\mathbf{s}}_{t}^{\text {node }}\right) αtnode =fattnode (A,s tnode ) c t node  = ∑ α t i node  a i \mathbf{c}_{t}^{\text {node }}=\sum \alpha_{t i}^{\text {node }} \mathbf{a}_{i} ctnode =αtinode ai
函数 f att  node  f_{\text {att }}^{\text {node }} fatt node 如下:
F node  = Q node  ∗ ∑ l = 1 t − 1 α l node  \mathbf{F}^{\text {node }}=\mathbf{Q}^{\text {node }} * \sum_{l=1}^{t-1} \boldsymbol{\alpha}_{l}^{\text {node }} Fnode =Qnode l=1t1αlnode  e t i node  = V node  T tanh ⁡ ( W att  node  s ~ t node  + U att  node  a i + U ^ F node  f i node  ) e_{t i}^{\text {node }}=V_{\text {node }}^{\mathrm{T}} \tanh \left(\mathbf{W}_{\text {att }}^{\text {node }} \tilde{\mathbf{s}}_{t}^{\text {node }}+\mathbf{U}_{\text {att }}^{\text {node }} \mathbf{a}_{i}+\hat{\mathbf{U}}_{\mathrm{F}}^{\text {node }} \mathbf{f}_{i}^{\text {node }}\right) etinode =Vnode Ttanh(Watt node s~tnode +Uatt node ai+U^Fnode finode ) α t i node  = exp ⁡ ( e t i node  ) ∑ k exp ⁡ ( e t k node  ) \alpha_{t i}^{\text {node }}=\frac{\exp \left(e_{t i}^{\text {node }}\right)}{\sum_{k} \exp \left(e_{t k}^{\text {node }}\right)} αtinode =kexp(etknode )exp(etinode )
α t i node  \alpha_{t i}^{\text {node }} αtinode 表示第t步的第i个元素的节点的可能性, e t i node  e_{t i}^{\text {node }} etinode 表示第i步的输出, f i n o d e \mathbf{f}_{i}^{\mathrm{node}} finode表示函数 F node  \mathbf{F}^{\text {node }} Fnode 的第i个元素,这是以前的注意模块,为了避免过度解析或者解析不足的问题,其余为学习参数。
接着,使用 c t node  \mathbf{c}_{t}^{\text {node }} ctnode  s ~ t node  \widetilde{\mathbf{s}}_{t}^{\text {node }} s tnode 作为 G R U 2 node  \mathbf{GRU}_{2}^{\text {node }} GRU2node 的输入来计算预测模块隐藏状态 s t node  \mathbf{s}_{t}^{\text {node }} stnode 
s t node  = GRU ⁡ 2 node  ( c t node  , s ~ t node  ) \mathbf{s}_{t}^{\text {node }}=\operatorname{GRU}_{2}^{\text {node }}\left(\mathbf{c}_{t}^{\text {node }}, \widetilde{\mathbf{s}}_{t}^{\text {node }}\right) stnode =GRU2node (ctnode ,s tnode )最后通过父节点 e t p \mathbf{e}_{t}^{\text {p}} etp,与父节点的关系 e t r \mathbf{e}_{t}^{\text {r}} etr,节点的隐藏状态 s t node  \mathbf{s}_{t}^{\text {node }} stnode 以及上下文向量 c t node  \mathbf{c}_{t}^{\text {node }} ctnode 的聚合来计算预测节点 o t node  \mathbf{o}_{t}^{\text {node }} otnode 的可能性: h t node  = maxout ⁡ ( W 1 node  [ e t p , e t r , s t node  , c t node  ] ) \mathbf{h}_{t}^{\text {node }}=\operatorname{maxout}\left(\mathbf{W}_{1}^{\text {node }}\left[\mathbf{e}_{t}^{\mathrm{p}}, \mathbf{e}_{t}^{\mathrm{r}}, \mathbf{s}_{t}^{\text {node }}, \mathbf{c}_{t}^{\text {node }}\right]\right) htnode =maxout(W1node [etp,etr,stnode ,ctnode ]) o t node  = softmax ⁡ ( W 2 node  h t node  ) \mathbf{o}_{t}^{\text {node }}=\operatorname{softmax}\left(\mathbf{W}_{2}^{\text {node }} \mathbf{h}_{t}^{\text {node }}\right) otnode =softmax(W2node htnode )其中W参数为学习参数。
我们使用cross-entropy函数来计算分类的损失函数 L node  = − ∑ log ⁡ o t node  ⋅ n t \mathcal{L}_{\text {node }}=-\sum \log \mathbf{o}_{t}^{\text {node }} \cdot \mathbf{n}_{t} Lnode =logotnode nt n t \mathbf{n}_{t} nt表示第t步节点真实值的独热向量。

上周只进行了模块分析和一部分代码的编写,这周将完成者模块代码的编写。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值