图神经网络改进-手把手教你改代码-第4期:GCN、GAT手动实现

  本系列项目主攻:代码分享与讲解创新思路解析前沿模块缝合二次创新实现方法
项目主要提供关于:

  • 图神经网络
  • 图对比学习
  • 图结构学习
  • 超图神经网络
  • 超图对比学习
  • 超图结构学习

  这六种方向的通用模型、原创代码以及改进思路,供大家参考学习,后续还会持续更新针对链路预测、节点分类等下游任务上的代码以及改进思路,帮助大家提升代码水平,多发论文。


 希望可以帮助大家快速上手实践图神经网络,实践是最好的入门方式!

  祝大家论文顺利,accept冲冲冲!

第4期:手动实现GCN、GAT模块


本期为大家带来了图神经网络基础模块:GCN、GAT的复现代码,并附有逐行注释和视频讲解


Q:为什么要自己实现这些图神经网络基础模块?
A:PYG虽然提供了许多基础模块供直接调用,但是不利于基础创新,或不方便获取这些模块的中间输出等。


详细讲解视频:【图神经网络改进-手把手教你改代码-第4期】

项目Github:图小狮


以下为具体的代码部分:

GCN类


1.类继承自nn.Module

class GCNConv(nn.Module):

2.__init__方法

def __init__(self, in_features, out_features):
        super(GCNConv, self).__init__()
        # 输入特征的维度
        self.in_features = in_features
        # 输出特征的维度
        self.out_features = out_features
        # 定义可学习的权重矩阵
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        # 定义可学习的偏置向量
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        # 初始化参数
        self.reset_parameters()

3.参数初始化

def reset_parameters(self):
        # 使用Xavier初始化权重矩阵
        nn.init.xavier_uniform_(self.weight)
        # 将偏置向量初始化为0
        nn.init.zeros_(self.bias)

4.forward方法

def forward(self, x, edge_index):
        # x: 节点特征矩阵,大小为[num_nodes, in_features]
        # edge_index: 边索引,大小为[2, num_edges]

        # 计算规范化的邻接矩阵
        num_nodes = x.size(0)
        # 创建一个全零矩阵作为邻接矩阵的初始状态
        adj = torch.zeros(num_nodes, num_nodes).cuda()
        # 根据edge_index填充邻接矩阵,无向图因此两个方向都要填充
        adj[edge_index[0], edge_index[1]] = 1
        adj[edge_index[1], edge_index[0]] = 1
        # 计算每个节点的度
        deg = torch.sum(adj, dim=1)
        # 计算度矩阵的逆平方根,用于后续的归一化
        deg_inv_sqrt = 1.0 / torch.sqrt(deg)  # 加上一个小的常数避免除零错误
        # 计算对称归一化的邻接矩阵
        norm_adj = adj * deg_inv_sqrt.unsqueeze(1)
        norm_adj = norm_adj * deg_inv_sqrt.unsqueeze(0)

        # 支撑传播:线性变换节点特征
        support = torch.matmul(x, self.weight)  # X * W
        # 消息传递:通过归一化的邻接矩阵传播特征
        output = torch.matmul(norm_adj, support) + self.bias  # D^{-1/2} * A * D^{-1/2} * (X * W) + b
        return output  # 返回输出,可以选择使用ReLU等激活函数进行非线性变换

GAT类


1.类继承自nn.Module

class GCNConv(nn.Module):

2.__init__方法

def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2, concat=True):
        super(GATConv, self).__init__()
        # 定义dropout率,用于在注意力系数上进行dropout操作以防止过拟合
        self.dropout = dropout
        # 输入特征的维度
        self.in_features = in_features
        # 输出特征的维度
        self.out_features = out_features
        # LeakyReLU非线性激活函数中的负斜率alpha
        self.alpha = alpha
        # 是否在多头注意力中进行拼接,对于最后一层通常设为False,使用平均
        self.concat = concat
        # 定义可学习的权重矩阵W,用于线性变换输入特征
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # 定义注意力机制中可学习的参数a
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        # 定义LeakyReLU激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)

3.注意力系数矩阵生成

def _prepare_attentional_mechanism_input(self, Wh):
        # 这个函数负责计算注意力系数
        # 首先通过与a的前半部分做矩阵乘法计算得到每个节点的影响力分数Wh1
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        # 通过与a的后半部分做矩阵乘法计算得到每个节点被影响的分数Wh2
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
        # 将Wh1加上Wh2的转置,得到每一对节点的非归一化注意力分数e
        e = Wh1 + Wh2.T
        # 使用LeakyReLU激活函数处理e,增加非线性
        return self.leakyrelu(e)

4.forward方法

def forward(self, h, edge_index):
        # 将边索引转换为稠密邻接矩阵,并去除多余的维度
        adj = to_dense_adj(edge_index).squeeze(0)
        # 应用线性变换
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        # 准备注意力机制的输入
        e = self._prepare_attentional_mechanism_input(Wh)
        # 创建一个足够小的向量用于掩盖不存在的边
        zero_vec = -9e15*torch.ones_like(e)
        # 只有当adj中存在边时,才保留e中的值,否则用zero_vec中的极小值代替
        attention = torch.where(adj > 0, e, zero_vec)
        # 对注意力系数进行softmax操作,使得每个节点的注意力系数和为1
        attention = F.softmax(attention, dim=1)
        # 对注意力系数进行dropout
        attention = F.dropout(attention, self.dropout, training=self.training)
        # 应用注意力机制更新节点特征
        h_prime = torch.matmul(attention, Wh)

        # 如果concat为真,则对输出使用ELU激活函数;否则直接返回结果
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

Powered By 图小狮

希望能够得到大家的喜欢,您的点赞收藏即是对我们最大的支持!

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在Pytorch中实现基于GCN/GAT/Chebnet的交通流预测,可以参考以下步骤: 1. 数据预处理:读入交通流数据,构建交通网络,将节点和边转换为矩阵表示。 2. 模型定义:定义GCN/GAT/Chebnet神经网络模型,包括输入层、隐藏层、输出层等。 3. 模型训练:使用交通流数据进行模型训练,通过计算损失函数来优化模型参数。 4. 模型测试:使用测试集数据进行模型测试,预测交通流情况,计算预测值与实际值之间的误差。 下面是一个基于GCN的交通流预测模型的Pytorch代码示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class GCN(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(GCN, self).__init__() self.linear1 = nn.Linear(input_dim, hidden_dim) self.linear2 = nn.Linear(hidden_dim, output_dim) def forward(self, x, adj): x = F.relu(self.linear1(torch.matmul(adj, x))) x = self.linear2(torch.matmul(adj, x)) return x ``` 该模型包括两个线性层,其中第一个线性层将输入节点特征乘以邻接矩阵,然后通过ReLU激活函数得到隐藏层的输出,第二个线性层将隐藏层的输出再次乘以邻接矩阵得到最终的输出。 在训练过程中,需要定义损失函数和优化器,如下所示: ```python criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) ``` 然后,使用交通流数据进行模型训练,如下所示: ```python for epoch in range(num_epochs): outputs = model(features, adj) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item())) ``` 在模型测试阶段,可以直接使用模型进行预测,如下所示: ```python with torch.no_grad(): test_outputs = model(test_features, adj) test_loss = criterion(test_outputs, test_labels) print('Test Loss: {:.4f}'.format(test_loss.item())) ``` 以上是基于GCN的交通流预测模型的Pytorch代码示例,类似的代码可以用于实现基于GAT/Chebnet的交通流预测模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值