pytorch框架下—GCN代码详细解读

说明:本文是对论文“SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS, ICLR 2017”中描述的GCN模型代码的详细解读。
代码下载地址:https://github.com/tkipf/pygcn
论文下载地址:https://arxiv.org/abs/1609.02907
数据集下载地址:https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz


代码结构总览

作为图神经网络的入门级代码,该代码相对来说结构简单。代码总共包含四个部分:

此处的代码结构图参考了CSDN博客“Graph Convolution Network图卷积网络(一)训练运行与代码概览

在这里插入图片描述

  • utils:定义了加载数据等工具性的函数
  • layers:定义了模块如何计算卷积
  • models:定义了模型train
  • train:包含了模型训练信息
    在这里插入图片描述

一、数据集结构、内容分析

1、数据集结构

论文中所使用的数据集合是Cora数据集,总共有三部分构成:cora.content cora.cites 和README。

README: 对数据集内容的描述;

cora.content: 里面包含有每一篇论文各自独立的信息;

该文件总共包含2078行,每一行代表一篇论文,由论文编号、论文词向量(1433维)和论文的类别三个部分组成

cora.cites: 里面包含有各论文之间的相互引用记录

该文件总共包含5429行,每一行是两篇论文的编号,表示右边的论文引用左边的论文。

2、数据集内容分析

该数据集总共有2078个样本,而且每个样本都为一篇论文。根据README可知,所有的论文被分为了7个类别,分别为:

  1. 基于案列的论文
  2. 基于遗传算法的论文
  3. 基于神经网络的论文
  4. 基于概率方法的论文
  5. 基于强化学习的论文
  6. 基于规则学习的论文
  7. 理论描述类的论文

此外,为了区分论文的类别,使用一个1433维的词向量,对每一篇论文进行描述,该向量的每个元素都为一个词语是否在论文中出现,如果出现则为“1”,否则为“0”。

二、utils代码分析

1、代码总览

import numpy as np
import scipy.sparse as sp
import torch


def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {
   c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def load_data(path="../data/cora/", dataset="cora"):
    """Load citation network dataset (cora only for now)"""
    print('Loading {} dataset...'.format(dataset))

    idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset),
                                        dtype=np.dtype(str))
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    labels = encode_onehot(idx_features_labels[:, -1])

    # build graph
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {
   j: i for i, j in enumerate(idx)}
    edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
                                    dtype=np.int32)
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)

    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    features = normalize(features)
    adj = normalize(adj + sp.eye(adj.shape[0]))

    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)

    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(np.where(labels)[1])
    adj = sparse_mx_to_torch_sparse_tensor(adj)

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    return adj, features, labels, idx_train, idx_val, idx_test


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels
  • 179
    点赞
  • 729
    收藏
    觉得还不错? 一键收藏
  • 45
    评论
PyTorch GCN代码是指使用PyTorch库实现图卷积网络(Graph Convolutional Networks,GCN)的代码。 图卷积网络是一种用于图数据的深度学习模型,主要用于节点分类、链接预测和图生成等任务。它通过对图结构进行卷积操作来提取节点的特征表示。而PyTorch是一种基于Python的开源深度学习框架,提供了丰富的神经网络模块和自动求导功能。 在使用PyTorch库实现GCN代码中,通常需要进行以下几个步骤: 1. 数据准备:需要将图数据转换为PyTorch可处理的数据格式,通常使用邻接矩阵和节点特征矩阵表示图结构和节点特征。 2. 模型定义:定义GCN模型的结构,通常包括多层图卷积层、激活函数和池化层等。每一层的输出作为下一层的输入,以逐层提取节点特征。 3. 模型训练:使用训练数据对定义的GCN模型进行训练,通常使用随机梯度下降(SGD)等优化算法来更新模型参数,以降低训练损失。 4. 模型评估:使用测试数据对训练好的模型进行评估,通常使用准确度、精确度、召回率等指标来评估模型的性能。 需要注意的是,代码的具体实现方式会因不同的GCN变体而有所差异,例如ChebNet、SpectralNet等。此外,代码中还可能包括数据预处理、结果可视化和超参数调优等过程。 总之,PyTorch GCN代码是指使用PyTorch库实现图卷积网络的代码,其实现过程涵盖数据准备、模型定义、模型训练和模型评估等步骤。具体实现方式会因GCN的变体而有所不同。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 45
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值