复现WL-DeepGCN代码过程及遇到的问题

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

复现WL-DeepGCN代码过程及遇到的问题


前言

github上代码不完整,对比了EV-GCN的代码,发现获取数据的代码类似,就添加到代码中
WL-DeepGCN代码链接


一、环境配置

代码中readme中详细讲了环境的具体配置

Python 3.7.12

Pytorch 1.13.0

Cuda 11.6

Numpy 1.21.6

scikit-learn 1.0.2

Nilearn 0.9.2

这里我配置的torch和cuda版本和文中一样

在这里插入图片描述
然后对应版本安装torch_geometric

在这里插入图片描述
然后还需要安装这些模块matplotlib,torchmetrics

遇到的问题1:

/root/miniconda3/envs/py37/lib/python3.7/site-packages/torch_geometric/typing.py:31: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: libpython3.7m.so.1.0: cannot open shared object file: No such file or directory
  warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
/root/miniconda3/envs/py37/lib/python3.7/site-packages/torch_geometric/typing.py:42: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: libpython3.7m.so.1.0: cannot open shared object file: No such file or directory
  warnings.warn(f"An issue occurred while importing 'torch-sparse'. "

这个警告消息表明在导入 torch-sparse 库时出现了问题,具体原因是系统无法找到 libpython3.7m.so.1.0 共享库文件。这个文件通常是 Python 3.7 的一个共享库。如果你使用的 Python 版本是 3.8 或更高版本,或者该共享库丢失了,就会出现此类问题。

解决方法:为了确实需要使用 Python 3.7 并且缺少 libpython3.7m.so.1.0,所以执行下面的代码解决

sudo apt-get update
sudo apt-get install libpython3.7

二、数据获取与处理

使用使用EV-GCN的fetch_data.py更改路径后,终端运行代码

python fetch_data.py

遇到问题1:从你提供的代码来看,features 变量是在每次交叉验证的循环中定义的。错误提示 NameError: name ‘features’ is not defined 表明在某个地方,features 变量在使用之前没有被正确定义。

File "Nested_CV.py", line 131, in <module>
    output, edge_weights = model(features, edge_index, edgenet_input)
NameError: name 'features' is not defined

解决方法:需要确保在所有情况下都定义 features 变量,无论 args.cuda 的值如何。你可以将 features 的定义移到 if args.cuda: 条件语句之外

for j in range(args.folds):
    print(' Starting the {}-{} Fold::'.format(i+1,j+1))
    node_ftr = dataloader.get_node_features(train_ind)
    edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
    edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
    
    model = GCN(input_dim = args.num_features,
                nhid = args.hidden, 
                num_classes = 2, 
                ngl = args.ngl, 
                dropout = args.dropout, 
                edge_dropout = args.edropout, 
                edgenet_input_dim = 2*nonimg.shape[1])
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    
    features = torch.tensor(node_ftr, dtype=torch.float32)
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.long)
    
    if args.cuda:
        model.cuda()
        features = features.cuda()
        edge_index = edge_index.cuda()
        edgenet_input = edgenet_input.cuda()
        labels = labels.cuda()
        fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)

遇到问题2:

Traceback (most recent call last):
  File "Nested_CV.py", line 146, in <module>
    acc_train = torchmetrics_accuracy(output[train_ind], labels[train_ind])
  File "/root/autodl-tmp/WL-DeepGCN-main/metrics.py", line 11, in torchmetrics_accuracy
    acc = torchmetrics.functional.accuracy(preds, labels)
TypeError: accuracy() missing 1 required positional argument: 'task'

解决方法:这个函数计算预测值的准确度。你可能需要指定任务类型,比如二分类或多分类,以确保函数能够正确处理。torchmetrics 的 accuracy 函数需要 task 参数,如果你处理的是二分类任务,可以修改为:

def torchmetrics_accuracy(preds, labels):
    acc = torchmetrics.functional.accuracy(preds, labels, task='multiclass', num_classes=2)
    return acc

三、代码运行

运行代码来进行测试和训练

python Nested_CV.py --train=1

值得高兴的是没有报错,但是结果不太好,ACC才60,和文中差距很大
然后调整了一下参数设置

在这里插入图片描述
最后调整了lr:0.01,epoch:300;去掉了早停,因为:早停可能会在模型还可以进一步改进的情况下提前终止训练。特别是当验证集性能指标偶尔出现波动时,早停可能会误判为性能停滞。

最后实验结果为请添加图片描述
结果还不算太差
如果还有什么进展,我再更新
收工!


总结

以上就是今天要讲的内容,本文仅仅简单介绍了复现DeepGCN代码过程及遇到的问题,还有其他问题欢迎交流~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值