Week 7.02 Tree LSTM Code Walkthrough


This post is compiled from 深度之眼's "GNN Core Competence Training Program".
For entering formulas, see: online LaTeX equations
For the earlier guided paper reading, see: Week 7 Live Session: Tree LSTM Paper Reading
The official tutorial code is here:
https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html

Task and Dataset

The dataset is the Stanford Sentiment Treebank (SST DATASET).
The dataset website shows example constituency trees: https://nlp.stanford.edu/sentiment/treebank.html. One is reproduced below; see the website for the rest:
[Figure: an example constituency tree from the SST website]
In the constituency tree, non-leaf nodes contain no word (they are filled with PAD_WORD and have no real representation; their embeddings are initialized to zero for training and testing, but non-leaf nodes still take part in message aggregation). The sentiment labels fall into 5 classes: very negative, negative, neutral, positive, and very positive.

Loading the Data

For this demo the original dataset is shrunk by loading it in 'tiny' mode, which contains only five sentences.
Words are represented as one-hot indices, i.e. integer ids into the vocabulary.

from collections import namedtuple

import dgl
from dgl.data.tree import SSTDataset


SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode='tiny')  # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes

vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():  # node feature tensor to a Python list
    if token != trainset.PAD_WORD:  # PAD_WORD marks non-leaf nodes, so keep leaves only
        print(inv_vocab[token], end=" ")  # map the id back to its word and print

Output:
the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
This sentence corresponds exactly to the constituency tree pictured above.
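
As a quick structural check of the claims above (my own addition, not in the tutorial; it relies on the 'x' and 'y' node fields described in the loading code):

leaf = a_tree.ndata['x'] != trainset.PAD_WORD            # PAD_WORD marks the non-leaf nodes
print('leaves:', int(leaf.sum()), 'of', a_tree.number_of_nodes(), 'nodes')
print('label range:', int(a_tree.ndata['y'].min()), '-', int(a_tree.ndata['y'].max()))  # 0-4, the 5 sentiment classes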

Step 1: Batching

An aside:

Installing graphviz has pitfalls; if anyone has solved them, please let me know in the comments.
First download and install Graphviz: http://www.graphviz.org/download/
During installation, remember to tick the option that adds it to the system PATH; then test it from cmd by creating a dot file with the following content:

//dot a.dot -Tpng -o a.png  -Gsplines=line  
digraph G {
	a -> b;//edge
	b -> c;//edge
	subgraph x{
		rank=same;//keep these nodes on the same row
		b->d;
	}
	subgraph y{
		rank = same;//keep these nodes on the same row
		d->e;
	}
	subgraph z{
		//rank=same;
		c->e;
	}
}

Run this in the same directory as the dot file:

dot a.dot -Tpng -o a.png  -Gsplines=line 

The result looks like this:
[Figure: the rendered a.png, a small directed graph]
which confirms the installation succeeded.
Next install PyGraphviz from here:
https://www.lfd.uci.edu/~gohlke/pythonlibs/#pygraphviz
Note that the numbers in the wheel filename correspond to your Python version; download the matching one, or it won't install.
[Figure: list of PyGraphviz wheel files with Python version tags]
After downloading I installed it with pip, but running it raised an error:

No module named _graphviz

I'll set this aside for now; skipping the plot doesn't affect anything else, so I just commented that part out.
My system is Windows 10.
Borrowing a figure from the official tutorial:
[Figure: how several trees are batched into one graph, from the DGL tutorial]

Code

import networkx as nx
import matplotlib.pyplot as plt

graph = dgl.batch(tiny_sst)
#def plot_tree(g):
#    # this plot requires pygraphviz package
#    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
#    nx.draw(g, pos, with_labels=False, node_size=10,
#            node_color=[[.5, .5, .5]], arrowsize=4)
#    plt.show()

#plot_tree(graph.to_networkx())

Here dgl.batch stacks the adjacency matrices of the individual trees along the diagonal of one big block-diagonal matrix, so all the trees can be computed on together as a single large graph.
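
A quick sanity check (my own addition, not in the tutorial): batching only stacks the trees, so the node and edge counts of the big graph are exactly the sums over the five trees:

total_nodes = sum(t.number_of_nodes() for t in tiny_sst)
total_edges = sum(t.number_of_edges() for t in tiny_sst)
print(graph.number_of_nodes() == total_nodes)  # True: node sets are concatenated
print(graph.number_of_edges() == total_edges)  # True: no edges are added between trees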

Step 2: Tree-LSTM cell with message-passing APIs

The original paper proposes two Tree-LSTM variants:
Child-Sum Tree-LSTMs
N-ary Tree-LSTMs
The implementation here targets binarized constituency trees, so it uses the N-ary Tree-LSTM.
In an N-ary Tree-LSTM, each node $j$ maintains a hidden representation $h_j$ (equation 6) and a memory cell $c_j$ (equation 5). Node $j$ takes two kinds of input: its own word input $x_j$ and the hidden states of its children $h_{jl}$, $1 \leq l \leq N$ (see equation 1; for the binary trees here, $N = 2$):

$$
\begin{aligned}
i_j & = \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\
f_{jk} & = \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U^{(f)}_{kl} h_{jl} + b^{(f)}\right), & (2)\\
o_j & = \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U^{(o)}_{l} h_{jl} + b^{(o)}\right), & (3)\\
u_j & = \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U^{(u)}_l h_{jl} + b^{(u)}\right), & (4)\\
c_j & = i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, & (5)\\
h_j & = o_j \cdot \textrm{tanh}(c_j), & (6)
\end{aligned}
$$
Now for the main point of this lesson: building on the previous post, how to implement the message passing with the three core functions:
message_func
reduce_func
apply_node_func
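
One implementation trick worth noting before reading the cell: since $i_j$, $o_j$ and $u_j$ apply the same pattern of linear maps to the same inputs, the three projections can be fused into a single matrix and split apart afterwards with th.chunk. A minimal sketch of the trick (my own illustration; the sizes are arbitrary):

import torch as th
import torch.nn as nn

x_size, h_size = 4, 4
W_iou = nn.Linear(x_size, 3 * h_size, bias=False)  # one fused projection for i, o, u
x = th.randn(2, x_size)                            # a batch of two nodes
i, o, u = th.chunk(W_iou(x), 3, dim=1)             # split back into the three gate pre-activations
print(i.shape, o.shape, u.shape)                   # each is torch.Size([2, 4])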

import torch as th
import torch.nn as nn

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()

        # Equations (1), (3), (4): W takes x_j as input, so the input dimension is x_size;
        # the i, o, u projections are concatenated, so the output dimension is 3 * h_size
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        # Equations (1), (3), (4): U takes the hidden states of the two children,
        # so the input dimension is 2 * h_size; again i, o, u are concatenated into 3 * h_size
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        # Equations (1), (3), (4): the bias is added to the two terms above,
        # so it matches their output dimension: 3 * h_size
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        # The input dimension must match h_cat in equation (2) (see reduce_func below): 2 * h_size
        # The output is multiplied element-wise with c_{jl} in equation (5), hence also 2 * h_size
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    # What gets aggregated? The children's memory cells c and hidden states h
    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    # How is it aggregated? Leaf nodes have no children, so reduce_func never fires for them;
    # every non-leaf node computes the second term of equation (5) here
    def reduce_func(self, nodes):
        # concatenate h_jl for equations (1), (2), (3), (4)
        # nodes.mailbox['h'] has shape: num_nodes × num_children (2 for a binary tree) × h_size (256 in this example)
        # view keeps dim 0 (num_nodes) and merges the last two dims: num_nodes × 1024 (i.e. 2 * h_size)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        # U_f is applied to h_cat, then the result is reshaped back to the
        # shape of nodes.mailbox['h']: num_nodes × num_children × h_size
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        # f and nodes.mailbox['c'] are both 3-D; sum over the children dimension (dim 1)
        c = th.sum(f * nodes.mailbox['c'], 1)
        # return the fused U-projection for equations (1), (3), (4), plus c
        return {'iou': self.U_iou(h_cat), 'c': c}

    # What happens when a node's representation is updated after aggregation?
    def apply_node_func(self, nodes):
        # In equations (1), (3), (4) the W x_j term is zero for non-leaf nodes
        # (PAD_WORD embeds to zero), so only the precomputed 'iou' and the bias remain
        iou = nodes.data['iou'] + self.b_iou
        # split the fused iou back into the three gates
        i, o, u = th.chunk(iou, 3, 1)
        # the outer nonlinearities of equations (1), (3), (4)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
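
To see the three functions working together before the full model, here is a minimal smoke test (my own sketch, not part of the tutorial): a three-node binary tree with two leaves feeding one root. The ndata initialization mirrors what forward() does in the model below; for simplicity the leaf 'iou' entries are left at zero.

x_size, h_size = 8, 8
cell = TreeLSTMCell(x_size, h_size)
g = dgl.graph(([0, 1], [2, 2]))            # edges point from the children (0, 1) to the root (2)
g.ndata['iou'] = th.zeros(3, 3 * h_size)   # in the real model, leaves get W_iou(x) here
g.ndata['h'] = th.zeros(3, h_size)
g.ndata['c'] = th.zeros(3, h_size)
dgl.prop_nodes_topo(g,
                    message_func=cell.message_func,
                    reduce_func=cell.reduce_func,
                    apply_node_func=cell.apply_node_func)
print(g.ndata['h'].shape)                  # torch.Size([3, 8]): every node now has a hidden state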

Step 3: Define traversal

[Animation: tree nodes being updated level by level, from the DGL tutorial]
The animation shows that a Tree-LSTM node can be updated if and only if all of its children have been updated.
In practice this means repeatedly updating the nodes whose in-degree is zero (initially the leaves) and then removing the updated nodes from the pending set, i.e. traversing the tree level by level from bottom to top (a hand-rolled sketch of this idea follows the output below).
DGL ships with a built-in traversal function for this:

# to heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print('Traversing one tree:')
print(dgl.topological_nodes_generator(trv_a_tree))

# to heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(trv_graph))

Traversal output:
[Figure: the printed traversal frontiers for one tree and for the whole batch]
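
The same frontiers can be reproduced by hand, which makes the "repeatedly take the in-degree-zero nodes" idea concrete (my own sketch, not from the tutorial):

def manual_topo_order(g):
    # yield frontiers of nodes whose remaining in-degree is zero,
    # then logically remove their outgoing edges
    indeg = g.in_degrees().clone()
    done = th.zeros(g.number_of_nodes(), dtype=th.bool)
    while not bool(done.all()):
        frontier = th.nonzero((indeg == 0) & ~done, as_tuple=False).squeeze(1)
        yield frontier
        done[frontier] = True
        _, dst = g.out_edges(frontier)
        for v in dst.tolist():
            indeg[v] -= 1

print(list(manual_topo_order(trv_a_tree)))  # should match dgl.topological_nodes_generator above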
One more short snippet follows. I couldn't make sense of its output at first; the catch is that printing traversal_order only shows the frontiers, while the effect of the propagation lives in the node feature 'a' (the print is fixed below).

import dgl.function as fn
import torch as th

trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(traversal_order,
                     message_func=fn.copy_src('a', 'a'),
                     reduce_func=fn.sum('a', 'a'))


print(trv_graph.ndata['a'])  # after propagation, each non-leaf node's 'a' should hold the number of leaves in its subtree (leaves keep 1, since nodes with no in-edges are never reduced)

# the following is a syntax sugar that does the same
# dgl.prop_nodes_topo(graph)

The Tree LSTM Model

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:  # optionally initialize from pretrained word vectors (e.g. GloVe)
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # to heterogeneous graph
        g = dgl.graph(g.edges())
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        # leaf nodes have no in-edges, so message_func and reduce_func never fire for them;
        # their W_iou(x) term is precomputed here and apply_node_func does the rest
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # prop_nodes_topo propagates messages following the topological order we pass in
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits

The main function

from torch.utils.data import DataLoader
import torch.nn.functional as F

device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay)

def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].to(device),
                        wordid=batch_trees.ndata['x'].to(device),
                        label=batch_trees.ndata['y'].to(device))
    return batcher_dev

train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=batcher(device),
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='sum') 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))
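        # Optional (my addition, not from the tutorial): sentence-level sentiment
        # lives at the root of each tree. Edges point from children to parents,
        # so roots are exactly the nodes with zero out-degree.
        root_mask = g.out_degrees() == 0
        root_acc = float(th.sum(th.eq(batch.label[root_mask], pred[root_mask]))) / int(root_mask.sum())
        print("Epoch {:05d} | Step {:05d} | Root Acc {:.4f} |".format(epoch, step, root_acc))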

Model parameter shapes:
[Figure: printed model summary showing each layer's parameter dimensions]
Results:
[Figure: training log with per-step loss and accuracy]
