GCN学习:用PyG实现自定义layers的GCN网络及训练(五)

本文详细介绍了如何使用PyG构建自定义的GCN网络,包括自定义layer的传播方式、GCN原理的节点解读,以及逐行代码解析。通过添加自环、节点特征线性变换、规范化等步骤,理解图卷积层中节点特征的聚合过程。
摘要由CSDN通过智能技术生成


目前的代码讲解基本都是直接使用PyG内置的包实现固定结构的网络层。虽然我们可以通过每层使用不同的传递方式来建立不同的网络,但是却不能自定义网络层的传递方式,对于做创新性的研究工作而言是一个不足。
本篇在 GCN学习:Pytorch-Geometric教程(二)的基础上,自定义了GCN的层传递方式(仍然是按照论文中的传递方式建立,但是我们以后也可以建立其他传递方式),其他代码与系列(二)的代码相同。

完整代码

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F

class GCNConv(MessagePassing):
    def __init__(self,in_channels,out_channels):
        super(GCNConv,self).__init__(aggr='add')
        self.lin=torch.nn.Linear(in_channels,out_channels)
    def forward(self,x,edge_index):
        edge_index, _ = add_self_loops(edge_index,num_nodes=x.size(0))
        x=self.lin(x)
        row,col=edge_index
        #计算度矩阵
        deg=degree(col,x.size(0),dtype=x.dtype)
        #度矩阵的-1/2次方
        deg_inv_sqrt=deg.pow(-0.5)
        norm=deg_inv_sqrt[row]*deg_inv_sqrt[col]
        return self.propagate(edge_index,x=x,norm=norm)
    def message(self,x_j,norm):
        return norm.view(-1,1)*x_j



ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
print(dataset)
print(dataset.num_node_features)
print(dataset.num_classes)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值