import dgl
import torch as th
from dgl.nn import GlobalAttentionPooling
g1 = dgl.rand_graph(3,4)# g1 is a random graph with 3 nodes and 4 edges
g1_node_feats = th.rand(3,5)# feature size is 5
gate_nn = th.nn.Linear(5,1)# the gate layer that maps node feature to scalar
gap = GlobalAttentionPooling(gate_nn)# create a Global Attention Pooling layer
temp=gap(g1, g1_node_feats)print(temp)#(1,5)
import dgl
import torch as th
from dgl.nn import SetTransformerEncoder
g1 = dgl.rand_graph(3,4)# g1 is a random graph with 3 nodes and 4 edges
g1_node_feats = th.rand(3,5)# feature size is 5
set_trans_enc = SetTransformerEncoder(5,4,4,20)#(输入维度 (input_dim),编码器层数 (num_layers),注意力头数 (num_heads),隐藏层维度 (hidden_dim):)
temp=set_trans_enc (g1,g1_node_feats)print(temp)#(3,5)
import dgl
import torch as th
from dgl.nn import SortPooling
g1 = dgl.rand_graph(3,4)# g1 is a random graph with 3 nodes and 4 edges
g1_node_feats = th.rand(3,5)# feature size is 5print(g1_node_feats)
sortpool = SortPooling(k=2)# create a sort pooling layer
temp=sortpool(g1,g1_node_feats)print(temp)#(1,5*k)即(1,10)
SortPooling 层是全局池化的一种形式,它将图中的节点根据某种排序标准进行排序,然后选择前 k 个节点的邻居信息进行聚合。这种池化操作有助于捕获图中的全局结构信息,并且可以减少图的规模,从而减少计算量。
下面是 SortPooling 层的一些关键点:
排序标准:SortPooling 层可以根据节点的特征值进行排序。例如,可以基于节点的度(即邻居数量)、节点的嵌入向量的大小或其他自定义的标准。
选择节点:层会根据排序结果选择前 k 个节点。这些节点通常被认为是图中最重要的节点,可能是图中的中心节点或者是具有较高度的节点。
聚合操作:对于选定的节点,SortPooling 层会聚合它们的邻居信息。聚合可以是求和、平均、最大值等。
输出:SortPooling 层的输出是一个固定大小的向量,这个向量包含了聚合后的全局信息。
应用场景:这种池化操作在处理图数据时非常有用,尤其是在需要捕获全局结构信息的场景中,比如社交网络分析、分子结构分析等。
参数 k:k 是一个超参数,表示在排序后选择的节点数量。这个值可以根据具体任务和图的大小来调整。