pytorch下GCN代码解读
ps:仅个人学习用,欢迎讨论
代码来源:https://github.com/rusty1s/pytorch_geometric/blob/master/examples/gcn.py
def main():
print("hello world")
main()
import os.path as osp
import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv # noqa
#GCN用于提取图的特征参数然后用于分类
#数据集初始化部分
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
help='Use GDC preprocessing.')
args = parser.parse_args()#是否使用GDC优化
dataset = 'CiteSeer'#训练用的数据集
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)#数据集存放位置
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())#数据初始化类,其dataset的基类是一个torch.utils.data.Dataset对象
data = dataset[0]#只有一个图作为训练数据