本系列项目主攻:代码分享与讲解、创新思路解析、前沿模块缝合及二次创新实现方法。
项目主要提供关于:
- 图神经网络
- 图对比学习
- 图结构学习
- 超图神经网络
- 超图对比学习
- 超图结构学习
这六种方向的通用模型、原创代码以及改进思路,供大家参考学习,后续还会持续更新针对链路预测、节点分类等下游任务上的代码以及改进思路,帮助大家提升代码水平,多发论文。
希望可以帮助大家快速上手实践图神经网络,实践是最好的入门方式!
祝大家论文顺利,accept冲冲冲!
第一期:图神经网络代码快速上手指南
详细讲解视频:【图神经网络改进-手把手教你改代码-第1期】
项目Github:图小狮
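以下为具体的代码部分(片段中省略了 import 语句,运行前需自行导入 argparse、torch、torch.optim(记为 optim)、torch_geometric 中的 Planetoid、transforms(记为 T)及各图卷积层、sklearn 的 TSNE 以及 matplotlib.pyplot(记为 plt)):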
1.常用超参数设计,其具体作用与设置经验讲解见讲解视频
# Command-line arguments
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    Real booleans pass through unchanged (so ``default=True`` is preserved);
    strings such as ``true``/``false``, ``yes``/``no``, ``1``/``0`` are
    accepted case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_arguments(argv=None):
    """Parse the training hyper-parameters.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv`` exactly as before.

    Returns:
        argparse.Namespace with dataset, hidden_dim, dropout_rate, model, lr,
        wd, epochs, tsne_drawing and tsne_colors.
    """
    parser = argparse.ArgumentParser(description='Base-Graph Neural Network')
    parser.add_argument('--dataset', choices=['Cora', 'Citeseer', 'Pubmed'], default='Cora',
                        help="Dataset selection")
    parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden layer dimension')
    parser.add_argument('--dropout_rate', type=float, default=0.5, help='Dropout rate')
    parser.add_argument('--model', choices=['GCN', 'GAT', 'SAGE', 'ChebNet', 'TransformerConv'],
                        default='TransformerConv', help="Model selection")
    # Explicit types: without them, values given on the command line stay strings
    # and break the optimizer (lr/wd) and the epoch loop (epochs).
    parser.add_argument('--lr', type=float, default=0.01, help="Learning Rate selection")
    parser.add_argument('--wd', type=float, default=5e-4, help="weight_decay selection")
    parser.add_argument('--epochs', type=int, default=200, help="train epochs selection")
    # choices=[True, False] without a converter rejected every CLI string; parse it instead.
    parser.add_argument('--tsne_drawing', type=_str2bool, default=True,
                        help="Whether to use tsne drawing")
    # nargs='+' lets users pass several colours: --tsne_colors red green blue
    parser.add_argument('--tsne_colors', nargs='+',
                        default=['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'],
                        help="colors")
    return parser.parse_args(argv)
2.数据集加载方法
# Load a dataset
def load_dataset(name):
    """Build the Planetoid dataset called *name*, rooted at ``dataset/<name>``.

    Node features are passed through ``T.NormalizeFeatures()`` (per the PyG
    docs this row-normalises ``x`` — confirm for the installed version).
    """
    root_dir = 'dataset/' + name
    feature_transform = T.NormalizeFeatures()
    return Planetoid(root=root_dir, name=name, transform=feature_transform)
3.绘图函数
# Visualise embeddings with t-SNE
def plot_points(z, y, colors=None):
    """Project embeddings *z* to 2-D with t-SNE and scatter-plot them by label.

    Args:
        z: embedding tensor, one row per node (assumed (num_nodes, dim) — TODO confirm).
        y: integer label tensor aligned with the rows of *z*; labels are assumed
            to be 0..C-1, since class ``i`` is drawn with mask ``y == i``.
        colors: optional list of matplotlib colour specs. Defaults to the global
            ``args.tsne_colors``. Colours are reused cyclically when there are more
            classes than colours.

    Side effects: saves a PNG whose name is built from the global ``args.model``
    and calls ``plt.show()``.
    """
    if colors is None:
        colors = args.tsne_colors
    # .cpu() before .numpy(): a CUDA tensor cannot be converted to NumPy directly
    # (y below already did this; z did not).
    z = TSNE(n_components=2).fit_transform(z.detach().cpu().numpy())
    classes = len(torch.unique(y))
    y = y.cpu().numpy()
    plt.figure(figsize=(8, 8))
    for i in range(classes):
        # Modulo avoids IndexError when a dataset has more classes than colours.
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i % len(colors)])
    plt.axis('off')
    plt.savefig('{} embeddings ues tnse to plt figure.png'.format(args.model))
    plt.show()
4.模型定义
# Model definition
class GNN(torch.nn.Module):
    """Two-layer graph network: conv -> ReLU -> dropout -> conv.

    Both layers come from the family selected by ``model_type``. The second
    layer's output is returned without an activation (raw class scores).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, model_type, dropout_rate):
        """Create the two convolution layers.

        Args:
            in_channels: input feature dimension.
            hidden_channels: width of the hidden layer.
            out_channels: output dimension (number of classes in the driver script).
            model_type: one of 'GCN', 'GAT', 'SAGE', 'ChebNet', 'TransformerConv'.
            dropout_rate: dropout probability applied between the two layers.

        Raises:
            NotImplementedError: if ``model_type`` is not a supported name.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        # Lambdas keep the layer classes lazily referenced: only the selected
        # family is ever constructed, just like the original if/elif chain.
        conv_factories = {
            'GCN': lambda i, o: GCNConv(i, o),
            'GAT': lambda i, o: GATConv(i, o),
            'SAGE': lambda i, o: SAGEConv(i, o),
            'ChebNet': lambda i, o: ChebConv(i, o, K=2),
            'TransformerConv': lambda i, o: TransformerConv(i, o),
        }
        if model_type not in conv_factories:
            supported = ', '.join(conv_factories)
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; expected one of: {supported}")
        make_conv = conv_factories[model_type]
        # conv1 is built before conv2, preserving parameter-initialisation order.
        self.conv1 = make_conv(in_channels, hidden_channels)
        self.conv2 = make_conv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return raw scores for every node given features *x* and *edge_index*."""
        hidden = self.conv1(x, edge_index).relu()
        # Dropout is active only in training mode (self.training).
        hidden = torch.nn.functional.dropout(hidden, p=self.dropout_rate, training=self.training)
        return self.conv2(hidden, edge_index)
5.训练函数与测试函数
def train(model, data, opt=None, loss_fn=None):
    """Run one full-batch training step on the nodes selected by ``data.train_mask``.

    Args:
        model: module called as ``model(data.x, data.edge_index)``.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        opt: optimizer to step. Defaults to the module-level ``optimizer``
            (created in the ``__main__`` section of this script).
        loss_fn: loss callable ``loss_fn(scores, targets)``. Defaults to the
            module-level ``criterion``.

    Returns:
        The training loss for this step as a Python float.
    """
    # Resolve the globals lazily so existing calls train(model, data) still work,
    # while importers can pass their own optimizer / loss instead of hitting NameError.
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    opt.zero_grad()
    out = model(data.x, data.edge_index)
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    opt.step()
    return loss.item()
# 测试函数
def test(model, data):
model.eval()
logits, accs = model(data.x, data.edge_index), []
for _, mask in data('train_mask', 'val_mask', 'test_mask'):
pred = logits[mask].max(1)[1]
acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
accs.append(acc)
return accs, logits
6.运行
if __name__ == "__main__":
    # Driver: load data, build the model, train for args.epochs epochs, then report.
    args = parse_arguments()
    dataset = load_dataset(args.dataset)
    data = dataset[0]
    model = GNN(in_channels=dataset.num_node_features, hidden_channels=args.hidden_dim,
                out_channels=dataset.num_classes, model_type=args.model, dropout_rate=args.dropout_rate)
    print(model)
    print(f"Loaded {args.dataset} dataset with {data.num_nodes} nodes and {data.num_edges} edges.")
    # Cast defensively: options declared without type= arrive as strings from the CLI.
    epochs = int(args.epochs)
    # optimizer / criterion are module-level globals read by train().
    optimizer = optim.Adam(model.parameters(), lr=float(args.lr), weight_decay=float(args.wd))
    criterion = torch.nn.CrossEntropyLoss()
    Best_Acc = []  # test accuracy recorded after every epoch
    logits = None
    # epochs + 1 so that exactly `epochs` epochs run (range(1, epochs) ran one fewer).
    for epoch in range(1, epochs + 1):
        loss = train(model, data)
        accs, logits = test(model, data)
        train_acc, val_acc, test_acc = accs
        print(f'Epoch: [{epoch:03d}/{epochs}], Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
        Best_Acc.append(test_acc)
    if args.tsne_drawing and logits is not None:
        plot_points(logits, data.y)
    print('---------------------------')
    if Best_Acc:
        # Highest test accuracy over all epochs. Note: this is selected on the test
        # split itself, so it is an optimistic estimate.
        print('Best Acc: {:.4f}'.format(max(Best_Acc)))
    else:
        print('Best Acc: n/a (no epochs were run)')
    print('---------------------------')
Powered By 图小狮
希望能够得到大家的喜欢,您的点赞收藏即是对我们最大的支持!
Bug fix:最后的汇总输出应为所有 epoch 中的最佳测试准确率,而不是最后一个 epoch 的准确率。请将
print('Best Acc: {:.4f}'.format(test_acc))
改为
print('Best Acc: {:.4f}'.format(max(Best_Acc)))