GNN: From Getting Started to Giving Up

This post describes how to implement a graph representation network using the Graph Isomorphism Network (GIN), covering node embedding, graph pooling, and graph-level prediction. First, the GINNodeEmbedding module computes an embedding for every node; then a graph pooling method (sum, mean, max, or an attention mechanism) aggregates the node representations into a graph representation; finally, a linear transformation maps the graph representation to a prediction. The post also walks through the GINNodeEmbedding module and how the GINConv layer handles node and edge features.

Task 06

Implementing a Graph Representation Network Based on the Graph Isomorphism Network (GIN)

Workflow of the graph representation module:

(1) Use the GINNodeEmbedding module to compute a node embedding (Node Embedding) for every node in the graph, producing node representations;
(2) Apply graph pooling to the node representations to obtain the graph representation;
(3) Apply one linear layer to turn the graph representation into the prediction for the graph.

Code:

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding

class GINGraphPooling(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module

        Args:
            num_tasks (int, optional): number of labels to be predicted, i.e. the dimension of the graph representation. Defaults to 1.
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of the node embeddings. Defaults to 300.
            residual (bool, optional): whether to add residual connections. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): "last" or "sum". With "last", only the node embeddings of the final layer are used; with "sum", the node embeddings of all layers are summed. Defaults to "last".
            graph_pooling (str, optional): pooling method over the node embeddings; one of "sum", "mean", "max", "attention" and "set2set". Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
		
        # Use the GINNodeEmbedding module to embed every node in the graph, producing node representations
        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function: aggregate the node representations into a whole-graph embedding
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, clamp the output to [0, 50], since the
            # prediction target is known to lie in (0, 50]
            return torch.clamp(output, min=0, max=50)
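
As a quick sanity check, the module can be exercised on a toy batch. The following is a hypothetical usage sketch (the gin_graph import path is assumed; OGB-style molecular graphs carry 9 categorical atom-feature columns and 3 categorical bond-feature columns):

import torch
from torch_geometric.data import Data, Batch
from gin_graph import GINGraphPooling  # assumed module path for the class above

# Two toy molecular graphs with integer (categorical) features,
# as expected by AtomEncoder / BondEncoder
g1 = Data(
    x=torch.zeros(4, 9, dtype=torch.long),          # 4 atoms
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
    edge_attr=torch.zeros(4, 3, dtype=torch.long),  # 4 directed bonds
)
g2 = Data(
    x=torch.zeros(3, 9, dtype=torch.long),
    edge_index=torch.tensor([[0, 1], [1, 0]]),
    edge_attr=torch.zeros(2, 3, dtype=torch.long),
)
batched_data = Batch.from_data_list([g1, g2])

model = GINGraphPooling(num_tasks=1, num_layers=5, emb_dim=300, graph_pooling="sum")
model.eval()
with torch.no_grad():
    out = model(batched_data)
print(out.shape)  # torch.Size([2, 1]): one prediction per graph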

Node Embedding Module Based on the Graph Isomorphism Network

(1) Use AtomEncoder to embed the raw (categorical) atom attributes, obtaining the layer-0 node representations.
(2) Then compute node representations layer by layer, from layer 1 up to layer num_layers; each layer takes the previous layer's node representations h_list[layer], the edges edge_index, and the edge attributes edge_attr as input.

import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # first map the categorical atom attributes to atom representations
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of the JK (jumping knowledge) connection
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation

GINConv: the Graph Isomorphism Convolution Layer

Message passing starts when forward() calls self.propagate(). Inside message(), the node tensor is split into target-node and source-node views (x_i and x_j), and this function defines the message sent from each source node to its target node; update() then computes the updated node embeddings from the aggregated messages, and forward() stitches the whole message-passing step together.
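
The gin_conv.py file itself is not shown in this post. Below is a minimal sketch of such a GINConv layer, modeled on the OGB reference implementation (BondEncoder is the encoder discussed in the next section):

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from mol_encoder import BondEncoder  # embeds categorical bond features

class GINConv(MessagePassing):
    """GIN convolution that also incorporates edge features."""

    def __init__(self, emb_dim):
        super(GINConv, self).__init__(aggr="add")  # GIN uses sum aggregation
        # MLP applied after neighbourhood aggregation
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim), nn.BatchNorm1d(2 * emb_dim),
            nn.ReLU(), nn.Linear(2 * emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0.0]))  # learnable epsilon from the GIN paper
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        # embed the categorical bond features into the same space as the nodes
        edge_embedding = self.bond_encoder(edge_attr)
        # (1 + eps) * h_v plus the sum of messages from the neighbours
        out = self.mlp((1 + self.eps) * x +
                       self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        # x_j: source-node features; fuse node and edge information
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # the aggregated messages are returned unchanged
        return aggr_out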

AtomEncoder and BondEncoder

In this example, the attributes of both the nodes (atoms) and the edges (chemical bonds) are categorical values. They live in different spaces, so they cannot be fused directly. By embedding them, we can map node attributes and edge attributes into a common space, in which node and edge information can then be combined. In GINConv, the x_j + edge_attr operation inside message() performs exactly this fusion of node and edge information.
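
For reference, here is a sketch of what these encoders look like, following OGB's mol_encoder implementation: each categorical feature column gets its own nn.Embedding table, and the per-column embeddings are summed.

import torch
from torch import nn
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

class AtomEncoder(nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        # one embedding table per categorical atom-feature column
        for dim in get_atom_feature_dims():
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        # x: [num_nodes, num_atom_features], integer-valued
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:, i])
        return x_embedding  # [num_nodes, emb_dim]

class BondEncoder(nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        # one embedding table per categorical bond-feature column
        for dim in get_bond_feature_dims():
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        # edge_attr: [num_edges, num_bond_features], integer-valued
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
        return bond_embedding  # [num_edges, emb_dim]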

Assignment

Draw the depth-1 to depth-3 WL subtrees rooted at nodes 6, 3 and 5 of the graph in the figure below.

(figure omitted)
