简介:本文实现了一个GAT图注意力机制的网络层,可以在Keras中像调用Dense网络层、Input网络层一样直接搭积木进行网络组合。
一,基本展示
如下图所示,我们输入邻接矩阵和节点特征矩阵之后,可以直接调用myGraphAttention网络层得到每一头的注意力输出(节点emdbeding),十分的方便。
注意:上图有个BUG,最终的输出层应该是8,和输入节点特征保持一致,上图只是举一个例子。
二,代码实现
2.1 GAT网络层
GAT网络层的代码如下。
from __future__ import absolute_import
from keras.activations import relu
from keras import activations, constraints, initializers, regularizers
from keras import backend as K
from keras.layers import Layer, Dropout, LeakyReLU
# 定义图卷积层
class myGraphAttention(Layer):
def __init__(self,
F_,
activation='relu',
use_bias=True,
drop_rate = 0,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
attn_kernel_initializer='glorot_uniform',
kernel_regularizer=None,
bias_regularizer=None,
attn_kernel_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
attn_kernel_constraint=None,
**kwargs):
self.F_ = F_ # 输出的节点embeding维度
self.activation = activations.get(activation) # 输出结果之前的激活函数
self.use_bias = use_bias
"""
其他代码………………
"""
super(myGraphAttention, self).__init__(**kwargs)
def build(self, input_shape):
"""
其他代码………………
"""
def call(self, inputs):
X = inputs[0] # 节点特征 (N x F)
A = inputs[1] # 邻接矩阵 (N x N)
"""
其他代码………………
"""
# 加上偏置
if self.use_bias:
node_features = K.bias_add(node_features, self.bias)
# 最终的输出之前得激活一下
output = self.activation(node_features)
return output
def compute_output_shape(self, input_shape):
output_shape = input_shape[0][0], self.output_dim
return output_shape
2.2 模型搭建
模型搭建的代码如下。
from keras.layers import Layer,Input,Dense,add,Lambda
from keras.models import Model
inp_adj_martrix = Input(shape=(5,5),name='adj_martrix')
inp_node_features = Input(shape=(5,8),name='node_features_martrix')
# 在这里直接调用网络层
flat0 = myGraphAttention(12,name="head_0")([inp_node_features,inp_adj_martrix])
flat1 = myGraphAttention(12,name="head_1")([inp_node_features,inp_adj_martrix])
flat2 = myGraphAttention(12,name="head_2")([inp_node_features,inp_adj_martrix])
flat = add([flat0,flat1,flat2])
lorder = 1
one_node = Lambda(lambda inp: inp[:,0,:],name = "the-first-node-feature")(flat)
o1 = Dense(32,activation="relu")(one_node)
o2 = Dense(32,activation="relu")(o1)
out = Dense(12)(o2)
model = Model([inp_adj_martrix,inp_node_features],[out])
创作不易,需要完整代码4_liao我哦。还有很多预测网络结构。