图神经网络改进-手把手教你改代码-第3.1期:图神经网络下游任务-节点分类

本文详细介绍了一个图神经网络模型,包括GCN、GAT、SAGE等不同类型的实现,以及训练和测试函数。重点讲解了节点分类任务的代码,并提供了GitHub链接和B站视频资源,旨在帮助读者快速上手并提升图神经网络技能。
摘要由CSDN通过智能技术生成

  本系列项目主攻:代码分享与讲解创新思路解析前沿模块缝合二次创新实现方法
项目主要提供关于:

  • 图神经网络
  • 图对比学习
  • 图结构学习
  • 超图神经网络
  • 超图对比学习
  • 超图结构学习

  这六种方向的通用模型、原创代码以及改进思路,供大家参考学习,后续还会持续更新针对链路预测、节点分类等下游任务上的代码以及改进思路,帮助大家提升代码水平,多发论文。


 希望可以帮助大家快速上手实践图神经网络,实践是最好的入门方式!

  祝大家论文顺利,accept冲冲冲!

第3.1期:图神经网络下游任务-节点分类


其实我们第一期的快速上手代码就是节点分类任务的代码,本期的讲解视频附有逐函数、逐行讲解!

详细讲解视频:【图神经网络改进-手把手教你改代码-第3期-1】

项目Github:图小狮


以下为代码的主要逻辑实现:

1.模型定义

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, model_type, dropout_rate):
        super(GNN, self).__init__()
        self.dropout_rate = dropout_rate

        if model_type == 'GCN':
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        elif model_type == 'GAT':
            self.conv1 = GATConv(in_channels, hidden_channels)
            self.conv2 = GATConv(hidden_channels, out_channels)
        elif model_type == 'SAGE':
            self.conv1 = SAGEConv(in_channels, hidden_channels)
            self.conv2 = SAGEConv(hidden_channels, out_channels)
        elif model_type == 'ChebNet':
            self.conv1 = ChebConv(in_channels, hidden_channels, K=2)
            self.conv2 = ChebConv(hidden_channels, out_channels, K=2)
        elif model_type == 'TransformerConv':
            self.conv1 = TransformerConv(in_channels, hidden_channels)
            self.conv2 = TransformerConv(hidden_channels, out_channels)
        else:
            raise NotImplementedError

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = torch.nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv2(x, edge_index)
        return x

2.训练与测试函数

def train(model, data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


def test(model, data):
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs, logits

3.运行模型

if __name__ == "__main__":
    args = parse_arguments()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = load_dataset(args.dataset)
    data = dataset[0].to(device)
    print(data)
    model = GNN(in_channels=dataset.num_node_features, hidden_channels=args.hidden_dim,
                out_channels=dataset.num_classes, model_type=args.model, dropout_rate=args.dropout_rate).to(device)
    print(model)
    print(f"Loaded {args.dataset} dataset with {data.num_nodes} nodes and {data.num_edges} edges.")
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    Best_Acc = []
    for epoch in range(1, args.epochs):
        loss = train(model, data)
        accs, log= test(model, data)
        train_acc, val_acc, test_acc = accs
        print(f'Epoch: [{epoch:03d}/200], Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
        Best_Acc.append(test_acc)
    if args.tsne_drawing == True:
        plot_points(log, data.y)
    print('---------------------------')
    print('Best Acc: {:.4f}'.format(test_acc))
    print('---------------------------')

Powered By 图小狮

希望能够得到大家的喜欢,您的点赞收藏即是对我们最大的支持!

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值