图神经网络改进-手把手教你改代码-第1期:图神经网络代码快速上手指南

  本系列项目主攻:代码分享与讲解创新思路解析前沿模块缝合二次创新实现方法
项目主要提供关于:

  • 图神经网络
  • 图对比学习
  • 图结构学习
  • 超图神经网络
  • 超图对比学习
  • 超图结构学习

  这六种方向的通用模型、原创代码以及改进思路,供大家参考学习,后续还会持续更新针对链路预测、节点分类等下游任务上的代码以及改进思路,帮助大家提升代码水平,多发论文。


 希望可以帮助大家快速上手实践图神经网络,实践是最好的入门方式!

  祝大家论文顺利,accept冲冲冲!

第一期:图神经网络代码快速上手指南


详细讲解视频:【图神经网络改进-手把手教你改代码-第1期】

项目Github:图小狮


以下为具体的代码部分:

1.常用超参数设计,其具体作用设置经验讲解见讲解视频

# 命令行参数
def parse_arguments():
    parser = argparse.ArgumentParser(description='Base-Graph Neural Network')
    parser.add_argument('--dataset', choices=['Cora', 'Citeseer', 'Pubmed'], default='Cora',
                        help="Dataset selection")
    parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden layer dimension')
    parser.add_argument('--dropout_rate', type=float, default=0.5, help='Dropout rate')
    parser.add_argument('--model', choices=['GCN', 'GAT', 'SAGE', 'ChebNet', 'TransformerConv'], default='TransformerConv',
                        help="Model selection")
    parser.add_argument('--lr', default=0.01, help="Learning Rate selection")
    parser.add_argument('--wd', default=5e-4, help="weight_decay selection")
    parser.add_argument('--epochs', default=200, help="train epochs selection")
    parser.add_argument('--tsne_drawing', choices=[True, False], default=True,
                        help="Whether to use tsne drawing")
    parser.add_argument('--tsne_colors', default=['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'], help="colors")
    return parser.parse_args()

2.数据集加载方法

# 加载数据集
def load_dataset(name):
    dataset = Planetoid(root='dataset/' + name, name=name, transform=T.NormalizeFeatures())
    return dataset

3.绘图函数

# 使用Tsne绘图
def plot_points(z, y):
    z = TSNE(n_components=2).fit_transform(z.detach().numpy())
    classes = len(torch.unique(y))
    y = y.cpu().numpy()
    plt.figure(figsize=(8, 8))
    for i in range(classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=args.tsne_colors[i])
    plt.axis('off')
    plt.savefig('{} embeddings ues tnse to plt figure.png'.format(args.model))
    plt.show()

4.模型定义

# 定义模型
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, model_type, dropout_rate):
        super(GNN, self).__init__()
        self.dropout_rate = dropout_rate

        if model_type == 'GCN':
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        elif model_type == 'GAT':
            self.conv1 = GATConv(in_channels, hidden_channels)
            self.conv2 = GATConv(hidden_channels, out_channels)
        elif model_type == 'SAGE':
            self.conv1 = SAGEConv(in_channels, hidden_channels)
            self.conv2 = SAGEConv(hidden_channels, out_channels)
        elif model_type == 'ChebNet':
            self.conv1 = ChebConv(in_channels, hidden_channels, K=2)
            self.conv2 = ChebConv(hidden_channels, out_channels, K=2)
        elif model_type == 'TransformerConv':
            self.conv1 = TransformerConv(in_channels, hidden_channels)
            self.conv2 = TransformerConv(hidden_channels, out_channels)
        else:
            raise NotImplementedError

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = torch.nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv2(x, edge_index)
        return x

5.训练函数与测试函数

def train(model, data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
# 测试函数
def test(model, data):
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs, logits

6.运行

if __name__ == "__main__":
    args = parse_arguments()
    dataset = load_dataset(args.dataset)
    data = dataset[0]
    model = GNN(in_channels=dataset.num_node_features, hidden_channels=args.hidden_dim,
                out_channels=dataset.num_classes, model_type=args.model, dropout_rate=args.dropout_rate)
    print(model)
    print(f"Loaded {args.dataset} dataset with {data.num_nodes} nodes and {data.num_edges} edges.")
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = torch.nn.CrossEntropyLoss()
    Best_Acc = []
    for epoch in range(1, args.epochs):
        loss = train(model, data)
        accs, log= test(model, data)
        train_acc, val_acc, test_acc = accs
        print(f'Epoch: [{epoch:03d}/200], Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
        Best_Acc.append(test_acc)
    if args.tsne_drawing == True:
        plot_points(log, data.y)
    print('---------------------------')
    print('Best Acc: {:.4f}'.format(test_acc))
    print('---------------------------')


Powered By 图小狮

希望能够得到大家的喜欢,您的点赞收藏即是对我们最大的支持!


 Bug fix:
print('Best Acc: {:.4f}'.format(test_acc))   --->   print('Best Acc: {:.4f}'.format(max(Best_Acc)))
  • 14
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
完成端口是一种异步输入/输出(I/O)模型,在网络编程中起到了重要的作用。它是Windows操作系统中提供的一种高效的I/O完成机制。 完成端口的工作方式是通过一个预先创建的I/O完成端口对象来管理I/O操作。在应用程序中,可以创建多个完成端口对象,用于不同的I/O操作。完成端口对象会与一个执行线程相关联,这个线程会在I/O操作完成时被唤醒。当一个I/O操作完成时,操作系统会将完成的消息发送给完成端口对象,并唤醒相应的线程。 使用完成端口的好处是可以实现高效的并发I/O操作。通过使用线程池,可以有效地处理多个客户端请求,并且不会因为等待I/O操作而造成线程的闲置。此外,完成端口还可以用于实现高性能的服务器应用程序,因为它能够轻松地处理大量的并发I/O操作。 完成端口的使用步骤如下: 1. 创建完成端口对象,并绑定执行线程。 2. 创建一个I/O请求(例如读取或写入操作)。 3. 将I/O请求与完成端口对象关联。 4. 执行I/O操作,并等待操作完成。 5. 当操作完成时,线程被唤醒,并处理完成的I/O请求。 总之,完成端口是一种强大而高效的I/O完成机制,在网络编程中非常实用。它能够提供高并发的I/O操作能力,使得应用程序能够高效地处理多个客户端请求。通过合理地利用完成端口,在网络编程中可以实现高性能和高效率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值