GNNs-for-Node-Classification 项目使用文档
1. 项目的目录结构及介绍
GNNs-for-Node-Classification/
├── data/
│ ├── processed/
│ └── raw/
├── models/
│ ├── gcn.py
│ ├── gat.py
│ └── graphsage.py
├── configs/
│ ├── default.yaml
│ └── custom.yaml
├── utils/
│ ├── preprocessing.py
│ └── metrics.py
├── main.py
├── requirements.txt
└── README.md
目录结构介绍
- data/: 存放数据文件,包括处理后的数据和原始数据。
- processed/: 处理后的数据文件。
- raw/: 原始数据文件。
- models/: 存放模型定义文件。
- gcn.py: 图卷积网络模型定义。
- gat.py: 图注意力网络模型定义。
- graphsage.py: GraphSAGE模型定义。
- configs/: 存放配置文件。
- default.yaml: 默认配置文件。
- custom.yaml: 自定义配置文件。
- utils/: 存放工具函数和辅助函数。
- preprocessing.py: 数据预处理函数。
- metrics.py: 评估指标函数。
- main.py: 项目启动文件。
- requirements.txt: 项目依赖文件。
- README.md: 项目说明文档。
2. 项目的启动文件介绍
main.py
main.py
是项目的启动文件,负责加载配置、数据预处理、模型训练和评估等任务。以下是 main.py
的主要功能模块:
import argparse
import yaml
from models import gcn, gat, graphsage
from utils import preprocessing, metrics
def main(config_path):
# 加载配置文件
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# 数据预处理
data = preprocessing.load_data(config['data_path'])
# 模型选择
if config['model'] == 'gcn':
model = gcn.GCN(config)
elif config['model'] == 'gat':
model = gat.GAT(config)
elif config['model'] == 'graphsage':
model = graphsage.GraphSAGE(config)
else:
raise ValueError("Unsupported model type")
# 模型训练
model.train(data)
# 模型评估
metrics.evaluate(model, data)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Graph Neural Networks for Node Classification")
parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to configuration file")
args = parser.parse_args()
main(args.config)
启动命令
python main.py --config configs/default.yaml
3. 项目的配置文件介绍
default.yaml
default.yaml
是项目的默认配置文件,包含模型训练所需的各种参数。以下是 default.yaml
的一个示例:
data_path: "data/processed/cora.npz"
model: "gcn"
learning_rate: 0.01
num_epochs: 200
hidden_dim: 16
dropout: 0.5
weight_decay: 5e-4
配置文件参数介绍
- data_path: 数据文件路径。
- model: 选择使用的模型类型(gcn, gat, graphsage)。
- learning_rate: 学习率。
- num_epochs: 训练轮数。
- hidden_dim: 隐藏层维度。
- dropout: Dropout 比例。
- weight_decay: 权重衰减系数。
通过修改 default.yaml
文件中的参数,可以调整模型的训练行为和性能。