飞桨pgl----初学GCN

更多PGL学习资料请前往paddle官网课程学习https://aistudio.baidu.com/aistudio/education/group/info/1956

一、安装pgl

首先安装paddlepaddle1.85,然后用下述命令安装pgl。
pip install paddlepaddle==1.8.5

二、使用pgl创建一张简单的图

import pgl
from pgl import graph  # 导入 PGL 中的图模块
import paddle.fluid as fluid # 导入飞桨框架
import numpy as np

def build_graph():
    # 定义图中的节点数目,我们使用数字来表示图中的每个节点
    num_nodes = 10

    # 定义图中的边集
    edge_list = [(2, 0), (2, 1), (3, 1),(4, 0), (5, 0),
             (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
             (7, 2), (7, 3), (8, 0), (9, 7)]

    # 随机初始化节点特征,特征维度为 d
    d = 16
    feature = np.random.randn(num_nodes, d).astype("float32")

    # 随机地为每条边赋值一个权重
    edge_feature = np.random.randn(len(edge_list), 1).astype("float32")

    # 创建图对象,最多四个输入
    g = graph.Graph(num_nodes = num_nodes,
                    edges = edge_list,
                    node_feat = {'feature':feature},
                    edge_feat ={'edge_feature': edge_feature})

    return g

g = build_graph()

此时g则代表已创建的图,可以打印其信息出来。
g.num_nodes # 图中10 节点个数
g.num_edges # 图中14 边个数

三、定义消息传输机制

def model_layer(gw, nfeat, efeat, hidden_size, name, activation):
    '''
    gw: GraphWrapper 图数据容器,用于在定义模型的时候使用,后续训练时再feed入真实数据
    nfeat: 节点特征
    efeat: 边权重
    hidden_size: 模型隐藏层维度
    activation: 使用的激活函数
    '''

    # 定义 send 函数
    def send_func(src_feat, dst_feat, edge_feat):
        # 将源节点的节点特征和边权重共同作为消息发送
        return src_feat['h'] * edge_feat['e']

    # 定义 recv 函数
    def recv_func(feat):
        # 目标节点接收源节点消息,采用 sum 的聚合方式
        return fluid.layers.sequence_pool(feat, pool_type='sum')

    # 触发消息传递机制
    msg = gw.send(send_func, nfeat_list=[('h', nfeat)], efeat_list=[('e', efeat)])
    output = gw.recv(msg, recv_func)
    output = fluid.layers.fc(output,
                    size=hidden_size,
                    bias_attr=False,
                    act=activation,
                    name=name)
    return output

然后将信息传输机制叠加即可构建一个简单的GCN网络,并加入损失函数进行训练。

class Model(object):
    def __init__(self, graph):
        """
        graph: 我们前面创建好的图
        """
        # 创建 GraphWrapper 图数据容器,用于在定义模型的时候使用,后续训练时再feed入真实数据
        self.gw = pgl.graph_wrapper.GraphWrapper(name='graph',
                    node_feat=graph.node_feat_info(),
                    edge_feat=graph.edge_feat_info())
        # 作用同 GraphWrapper,此处用作节点标签的容器
        self.node_label = fluid.layers.data("node_label", shape=[None, 1],
                    dtype="float32", append_batch_size=False)

    def build_model(self):
        # 定义两层model_layer
        output = model_layer(self.gw, 
                             self.gw.node_feat['feature'], 
                             self.gw.edge_feat['edge_feature'],
                             hidden_size=8, 
                             name='layer_1', 
                             activation='relu')
        output = model_layer(self.gw, 
                             output, 
                             self.gw.edge_feat['edge_feature'],
                             hidden_size=1, 
                             name='layer_2', 
                             activation=None)
                             
        # 对于二分类任务,可以使用以下 API 计算损失
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(x=output, 
                                                              label=self.node_label)
        # 计算平均损失
        loss = fluid.layers.mean(loss)
        
        # 计算准确率
        prob = fluid.layers.sigmoid(output)
        pred = prob > 0.5
        pred = fluid.layers.cast(prob > 0.5, dtype="float32")
        correct = fluid.layers.equal(pred, self.node_label)
        correct = fluid.layers.cast(correct, dtype="float32")
        acc = fluid.layers.reduce_mean(correct)

        return loss, acc

至此,一个基于paddle及pgl的简单GCN网络搭建完成。

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值