GATNE代码部分讲解

GATNE介绍

本次要解读的GATNE模型,首次被提出在KDD2019的文章《Representation Learning for Attributed Multiplex Heterogeneous Network》,github地址为THUDM/GATNE,理论方面诸如公式、原理可以看《阿里电商场景下的大规模异构网络表示学习》,笔者在此主要从代码方面解读该模型,方便读者理解与实践。其中如有理解出现偏颇的地方,还请大家批评指正。

输入

首先我们关注模型的输入与输出,在github项目的readme.md中,将数据集分为train.txt/valid.txt/test.txt,GATNE的强大在于这个模型不仅可以解决同构网络,还可以解决异构网络问题。

文件名 格式
train.txt edge_type node1 node2
valid.txt edge_type node1 node2 label
test.txt edge_type node1 node2 label
node_type.txt node node_type 当解决同构网络时则不需要此文件
feature.txt node f_1 f_2 …f_dim

输出

输出测试集的ROC-AUC、F1和PR-AUC
也可以增加相应代码获得各个边对应下的节点向量

模型核心代码

这一部分,笔者在代码上标记有注释,主要点明各个张量的维度,代码只截取GATNE/src/runs/main.py的主要代码段。

def train_model(network_data, feature_dic, log_name):
    # 按照节点类型划分的数据点和不按照节点类型划分的数据点
    base_walks, all_walks = generate_walks(network_data)
    # 获得id_to_word word_to_id
    vocab, index2word = generate_vocab([base_walks])
    # 获得训练节点对
    train_pairs = generate_pairs(all_walks, vocab)
    # 获取边的类型
    edge_types = list(network_data.keys())
    # 获取节点的数量
    num_nodes = len(index2word)
    # 因为edge_types里面有base,所以需要减1
    edge_type_count = len(edge_types) - 1
    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions  # Dimension of the embedding vector.
    embedding_u_size = args.edge_dim  # 边的维度
    u_num = edge_type_count  # 边的类型数量
    num_sampled = args.negative_samples  # Number of negative examples to sample.
    dim_a = args.att_dim  # 默认为20
    att_head = 1
    neighbor_samples = args.neighbor_samples
    # 三维的列表,[网络中节点个数,边的类型数量,1]
    neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
    for r in range(edge_type_count):
        g &#
  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值