获取Cora数据集,查看训练、测试、验证集比例
from torch_geometric.datasets import Planetoid
# '''
# 下载报错,将所有data文件下载到本地
# https://github.com/kimiyoung/planetoid
# 将cora相关文件放入到raw文件中
# '''
dataset = Planetoid(root='./tmp/Cora',name='Cora')
print((dataset[0].train_mask).sum())
print((dataset[0].test_mask).sum())
print((dataset[0].val_mask).sum())
print(dataset[0])
获取模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
构建GCN网络,训练 验证
class GCN_Net(torch.nn.Module):
def __init__(self, features, hidden, classes):
super(GCN_Net, self).__init__()
self.conv1 = GCNConv(features, hidden)
self.conv2 = GCNConv(hidden, classes)
def forward(self, data):
x, edge_index = data.x, data.