Graph Convolution Network图卷积网络(一)训练运行与代码概览

背景:看懂并运行Graph Convolution Network的pytorch代码。

代码地址https://github.com/tkipf/pygcn

论文地址https://arxiv.org/abs/1609.02907 Semi-Supervised Classification with Graph Convolutional Networks,ICLR 2017

目录

一、运行

1.1 搭建环境

1.2 运行

二、数据集概览

2.1 总览

2.2 content file

2.3 cites file

三、代码概览

2.1 train.py

2.2 models.py

2.3 layers.py

2.4 utils.py


一、运行

1.1 搭建环境

搭建python、torch及显卡环境,

CentOS 6.3安装anaconda并配置pytorch与cuda

1.2 运行

运行很简单,直接环境搭建好直接就可运行,并且没有复杂的命令行。

python setup.py install

python train.py

迭代200次,很快即可运行完毕。

Epoch: 0198 loss_train: 0.4363 acc_train: 0.9571 loss_val: 0.7047 acc_val: 0.8233 time: 0.0076s
Epoch: 0199 loss_train: 0.3899 acc_train: 0.9500 loss_val: 0.7030 acc_val: 0.8233 time: 0.0076s
Epoch: 0200 loss_train: 0.4644 acc_train: 0.9071 loss_val: 0.7015 acc_val: 0.8200 time: 0.0076s
Optimization Finished!
Total time elapsed: 2.2366s
Test set results: loss= 0.7510 accuracy= 0.8320

二、数据集概览

core数据集  https://github.com/tkipf/pygcn/tree/master/data/cora

2.1 总览

该数据集就是图结构的数据集。机器学习的paper,表示用paper ID,共2708个paper。包含了下面这些类:

		Case_Based
		Genetic_Algorithms
		Neural_Networks
		Probabilistic_Methods
		Reinforcement_Learning
		Rule_Learning
		Theory

共2708个paper,7个类,一共有1433个关键词,关键词频次小于10会被删掉。

2.2 content file

包含下面:

<paper_id> <word_attributes>+ <class_label>

最开始为paper_id,后面为每一个单词是否出现,最后为类的标签。

第一列为paper ID,后面几列为每个单词出现与否,用0与1表示,最后一列为类别标签。

 

2.3 cites file

表示论文之间的引用关系(由此看来这个图是有向图)

<ID of cited paper> <ID of citing paper>

前面为被引论文的ID,后面为引用前面的论文的ID

887	6215
887	64519
887	87363
887	976334
906	1103979
906	1105344
906	1114352
906	1136397

三、代码概览

看到论文中有大量公式及推导,但是实际代码量很少。相当于只有四个代码有实际语义层面的信息。

2.1 train.py

包含模型训练等等的信息。

加载参数

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

生成随机种子

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

加载数据

# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()

定义模型与optimizer

# Model and optimizer
model = GCN(nfeat=features.shape[1],
            nhid=args.hidden,
            nclass=labels.max().item() + 1,
            dropout=args.dropout)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)

数据写入cuda,便于后续加速

if args.cuda:
    model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

定义训练函数

def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    if not args.fastmode:
        # Evaluate validation set performance separately,
        # deactivates dropout during validation run.
        model.eval()
        output = model(features, adj)

    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

定义测试函数,相当于对已有的模型再测试集上运行对应的loss与accuracy

逐个epoch进行train,最后test

# Train model
t_total = time.time()
for epoch in range(args.epochs):
    train(epoch)
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Testing
test()

2.2 models.py

定义了模型GCN

2.3 layers.py

定义了模型如何卷积

2.4 utils.py

定义了加载数据等工具性的函数,这几个函数后面再详细解析。

Graph Convolution Network图卷积网络(二)数据加载与网络结构定义

  • 43
    点赞
  • 206
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 21
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

祥瑞Coding

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值