本文分享一个对Keras版GAT源码的分析。
GAT 原文：https://arxiv.org/abs/1710.10903 。建议结合知乎 superbrother 大神的文章一起阅读理解。
TensorFlow版可以看:GitHub - PetarV-/GAT: Graph Attention Networks (https://arxiv.org/abs/1710.10903)
源代码 github:GitHub - danielegrattarola/keras-gat: Keras implementation of the graph attention networks (GAT) by Veličković et al. (2017; https://arxiv.org/abs/1710.10903)
1 utils.py
utils.py 定义了数据的加载、预处理，以及对邻接矩阵（adj）进行归一化（normalize）的函数。
from __future__ import print_function
import os
import pickle as pkl
import sys
import networkx as nx
import numpy as np
import scipy.sparse as sp
def parse_index_file(filename):
    """Parse an index file into a list of integers.

    Each non-blank line of *filename* must contain a single integer
    (for example the ``ind.<dataset>.test.index`` files, which hold the
    indices of the test instances in the graph).

    :param filename: Path of the index file to read.
    :return: List of the integers, in the order they appear in the file.
    :raises ValueError: If a non-blank line is not a valid integer.
    """
    # A context manager closes the file deterministically; the original
    # iterated over a bare open() and never closed the handle.
    with open(filename) as f:
        # Skip blank lines (e.g. a stray empty line at the end of the file),
        # which would otherwise make int('') raise ValueError.
        return [int(line.strip()) for line in f if line.strip()]
def sample_mask(idx, l):
"""Create mask."""
mask = np.zeros(l)
mask[idx] = 1
return np.array(mask, dtype=np.bool)
def load_data(dataset_str):
"""Load data."""
"""
Loads input data from gcn/data directory
ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
(a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
object;
ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.
All objects above must be saved using python pickle module.
:param dataset_str: Dataset name
:return: All data input files loaded (as well the training/test data).
"""
"""
从data文件中读取数据,文档中包含3组数据'cora', 'citeseer',与'pubmed',每组数据有8种类型。
'x'为训练数据的特征向量,
'tx'为测试数据的特征向量,
'allx','ally':整个graph上除test外的(包括train和val) 所有data的特征和便签,
'y'为训练数据的标签,'ty'为测试数据的标签,
'index'为测试数据的ID,
'graph'为图数据。
"""
#返回的是.py文件的绝对路径
FILE_PATH = os.path.abspath(__file__)
#返回的是.py文件的目录
DIR_PATH = os.path.dirname(FILE_PATH)
#到这里返回了data文件夹的绝对路径
DATA_PATH = os.path.join(DIR_PATH, 'data/')
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
objects = []
for i in range(len(names)):
with open("{}ind.{}.{}".format(DATA_PATH, dataset_str, names[i]), 'rb') as f:
if sys.version_info > (3, 0):
objects.append(pkl.load(f, encoding='latin1'))
else:
objects.append(pkl.load(f))
# x.shape:(140, 1433); y.shape:(140, 7);tx.shape:(1000, 1433);ty.shape:(1708, 1433);
# allx.shape:(1708, 1433);ally.shape:(1708, 7)
x, y, tx, ty, allx, ally, graph = tuple(objects)
# 训练数据集
# print(x[0][0],x.shape,type(x)) ##x是一个稀疏矩阵,记住1的位置,140个实例,每个实例的特征向量维度是1433 (140,1433)
# print(y[0],y.shape) ##y是标签向量,7分类,140个实例 (140,7)
##测试数据集
# print(tx[0][0],tx.shape,type(tx)) ##tx是一个稀疏矩阵,1000个实例,每个实例的特征向量维度是1433 (1000,1433)
# print(ty[0],ty.shape) ##y是标签向量,7分类,1000个实例 (1000,7)
##allx,ally和上面的形式一致
# print(allx[0][0],allx.shape,type(allx)) ##tx是一个稀疏矩阵,1708个实例,每个实例的特征向量维度是1433 (1708,1433)
# print(ally[0],al