GAT in DGL
GAT introduces an attention mechanism to replace the static, normalized aggregation of graph convolution: the weight on each edge is learned from the node features rather than fixed by the graph structure.
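The comments in the code below refer to the four equations of a single GAT layer. For reference, they follow the standard GAT formulation (Veličković et al., 2018), which the hand-written layer implements step by step:

$$z_i^{(l)} = W^{(l)} h_i^{(l)} \tag{1}$$
$$e_{ij}^{(l)} = \mathrm{LeakyReLU}\!\left(\vec{a}^{(l)\top} \left[z_i^{(l)} \,\Vert\, z_j^{(l)}\right]\right) \tag{2}$$
$$\alpha_{ij}^{(l)} = \frac{\exp\!\left(e_{ij}^{(l)}\right)}{\sum_{k \in \mathcal{N}(i)} \exp\!\left(e_{ik}^{(l)}\right)} \tag{3}$$
$$h_i^{(l+1)} = \sigma\!\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(l)} z_j^{(l)}\right) \tag{4}$$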
from dgl.nn.pytorch import GATConv

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
Multi-head attention
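Analogous to multiple channels in a ConvNet, multi-head attention runs K independent copies of the layer above and merges their outputs. Following the GAT paper, the two merge rules are concatenation and averaging, with averaging typically reserved for the final prediction layer:

$$\text{concat:}\quad h_i^{(l+1)} = \Big\Vert_{k=1}^{K} \sigma\!\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j^{(l)}\Big)$$
$$\text{average:}\quad h_i^{(l+1)} = \sigma\!\Big(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j^{(l)}\Big)$$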
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge by averaging over the head dimension (dim=0 of the stack)
            return torch.mean(torch.stack(head_outs), dim=0)
Put everything together
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim * num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head is used in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
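Note that the GATConv imported at the top is never actually used: the layers above are written out by hand to expose the message passing. For comparison, a minimal sketch (not part of the original code, assuming GATConv's default arguments) of the same two-layer model built on DGL's GATConv could look like the following. GATConv returns per-head outputs of shape (N, num_heads, out_dim), so heads are flattened between layers and the single output head is squeezed out at the end.

# Sketch only: an equivalent model using DGL's built-in GATConv.
class BuiltinGAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(BuiltinGAT, self).__init__()
        self.g = g
        self.layer1 = GATConv(in_dim, hidden_dim, num_heads)
        # one attention head in the output layer, as in the hand-written GAT above
        self.layer2 = GATConv(hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        # GATConv outputs (N, num_heads, out_dim); flatten heads before the next layer
        h = self.layer1(self.g, h).flatten(1)
        h = F.elu(h)
        # single output head: average away the head dimension
        h = self.layer2(self.g, h).mean(1)
        return h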
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx


def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask
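Side note: citation_graph.load_cora() and DGLGraph(data.graph) follow the older DGL data API that was current when this tutorial was written. On newer DGL releases, a roughly equivalent loader (a sketch, using the standard ndata keys of CoraGraphDataset) would look like:

# Sketch for newer DGL versions, where the dataset returns a graph directly.
from dgl.data import CoraGraphDataset

def load_cora_data_new():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    mask = g.ndata['train_mask']
    return g, features, labels, mask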
import time

import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))
Output:
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9449 | Time(s) nan
Epoch 00001 | Loss 1.9425 | Time(s) nan
Epoch 00002 | Loss 1.9402 | Time(s) nan
Epoch 00003 | Loss 1.9378 | Time(s) 0.2313
Epoch 00004 | Loss 1.9355 | Time(s) 0.2302
Epoch 00005 | Loss 1.9331 | Time(s) 0.2306
Epoch 00006 | Loss 1.9307 | Time(s) 0.2308
Epoch 00007 | Loss 1.9283 | Time(s) 0.2318
Epoch 00008 | Loss 1.9260 | Time(s) 0.2325
Epoch 00009 | Loss 1.9236 | Time(s) 0.2334
Epoch 00010 | Loss 1.9212 | Time(s) 0.2338
Epoch 00011 | Loss 1.9188 | Time(s) 0.2339
Epoch 00012 | Loss 1.9164 | Time(s) 0.2340
Epoch 00013 | Loss 1.9139 | Time(s) 0.2351
Epoch 00014 | Loss 1.9115 | Time(s) 0.2355
Epoch 00015 | Loss 1.9091 | Time(s) 0.2352
Epoch 00016 | Loss 1.9066 | Time(s) 0.2352
Epoch 00017 | Loss 1.9042 | Time(s) 0.2352
Epoch 00018 | Loss 1.9017 | Time(s) 0.2350
Epoch 00019 | Loss 1.8992 | Time(s) 0.2346
Epoch 00020 | Loss 1.8967 | Time(s) 0.2346
Epoch 00021 | Loss 1.8942 | Time(s) 0.2345
Epoch 00022 | Loss 1.8917 | Time(s) 0.2345
Epoch 00023 | Loss 1.8892 | Time(s) 0.2345
Epoch 00024 | Loss 1.8867 | Time(s) 0.2344
Epoch 00025 | Loss 1.8841 | Time(s) 0.2343
Epoch 00026 | Loss 1.8816 | Time(s) 0.2341
Epoch 00027 | Loss 1.8790 | Time(s) 0.2341
Epoch 00028 | Loss 1.8764 | Time(s) 0.2340
Epoch 00029 | Loss 1.8738 | Time(s) 0.2339