目录
-
不使用GPU
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GATLayer(nn.Module):
    """Single-head graph attention layer bound to a fixed graph.

    The graph ``g`` is stored at construction time. Every ``forward`` call runs
    attention-weighted message passing over it: project node features, score
    each edge, softmax-normalise the scores per destination node, then sum the
    weighted neighbour features.

    Args:
        g: Graph to propagate over. It must provide ``ndata``, ``apply_edges``
            and ``update_all`` (presumably a DGL graph — verify against caller).
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        negative_slope: Slope of the LeakyReLU applied to the raw attention
            logits. Defaults to 0.01, which is ``F.leaky_relu``'s default and
            what this layer has always used. The GAT paper uses 0.2.
    """

    def __init__(self, g, in_dim, out_dim, negative_slope=0.01):
        super().__init__()
        self.g = g
        self.negative_slope = negative_slope
        # Shared projection W applied to every node's feature vector.
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # Attention vector a: scores the concatenation [W h_src || W h_dst].
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        """Compute one unnormalised attention logit per edge.

        Args:
            edges: Edge batch exposing ``src['z']`` and ``dst['z']``, the
                projected features of each edge's endpoints.

        Returns:
            dict: ``{'e': logits}``, stored as edge data ``'e'`` by
            ``apply_edges`` and read back in ``message_func``.
        """
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a, negative_slope=self.negative_slope)}

    def message_func(self, edges):
        """Send each source's projected feature and the edge logit to the destination."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Aggregate incoming messages using softmax-normalised attention weights.

        Returns:
            dict: ``{'h': aggregated}``, the attention-weighted sum of the
            neighbours' projected features for each node.
        """
        # dim=1 is the incoming-message axis of the mailbox. Normalising there
        # makes each node's weights over its in-edges sum to 1.
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        """Run one round of graph-attention message passing over ``self.g``.

        Args:
            h: Node features. Assumed to be (num_nodes, in_dim) — TODO confirm.
                The projection is assigned to ``g.ndata['z']``, so its first
                dimension must match the graph's node count.

        Returns:
            The aggregated node features ``'h'`` produced by ``reduce_func``.
            They are popped from ``g.ndata``, so they are not left on the graph.

        Note:
            The intermediate ``'z'`` node feature and ``'e'`` edge feature stay
            on ``self.g`` after the call and are overwritten on the next call.
        """
        z = self.fc(h)  # eq. 1: z_i = W h_i
        self.g.ndata['z'] = z
        self.g.apply_edges(self.edge_attention)  # eq. 2: per-edge logits
        # eq. 3 and 4: softmax over in-edges, then weighted sum.
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
class MultiHeadGATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim, num_heads, merge='