GNN学习笔记
GNN从入门到精通课程笔记
3.2 GraphSAGE (NIPS '17)
- Inductive Representation Learning on Large Graphs (NIPS '17)
Abstract
- Problem: existing approaches require that all nodes in the graph are present during training of the embeddings
- Solution: learn a function that generates embeddings by sampling and aggregating features from a node’s local neighborhood. (inductive framework )
Introduction
- In this work we both extend GCNs to the task of inductive unsupervised learning and propose a framework that generalizes the GCN approach to use trainable aggregation functions (beyond simple convolutions).
- GraphSAGE: sample and aggregate
- By incorporating node features in the learning algorithm, we simultaneously learn the topological structure of each node’s neighborhood as well as the distribution of node features in the neighborhood.
- Our algorithm can also be applied to graphs without node features.
Proposed method: GraphSAGE
-
Problem: how to aggregate feature information from a node’s local neighborhood
-
Embedding generation (forward propagation)
- Feature: $x_v$
- $h_{\mathcal{N}(v)}^k$: 第k个aggregator输出的节点v的邻居的representation的聚合结果
- $h_v^k$: 第k个aggregator输出的节点v的representation,是将第k-1个aggregator输出的自己的representation($h_v^{k-1}$)和邻居的representation($h_{\mathcal{N}(v)}^k$)连接起来,然后经过线性和非线性变换得到的。
- Neighborhood definition
- uniformly sample a fixed-size set of neighbors
-
Learning the parameters of GraphSAGE
- Unsupervised learning:Random walk
- Co-occurring nodes on the random walk (相近的节点): $z_v$, 越相似越好($z_u^{T}z_v$越大)
- Negative nodes: $z_{v_n}$,越不像越好($-z_u^{T}z_{v_n}$越大)
- Supervised learning: task-specific objective
- Unsupervised learning:Random walk
-
Aggregator Architectures: symmetric (invariant to permutations of its inputs)
-
Mean aggregator
-
LSTM aggregator: applying the LSTMs to a random permutation of the node’s neighbors
-
Pooling aggregator
-
Appendix: minibatch version of the forward propagation
- 等价于在一个子图里运行algorithm1,这个子图是由minibatch节点的k-hop neighborhood组成的
- Implement
[1] https://github.com/twjiang/graphSAGE-pytorch.git
class SageLayer(nn.Module):
    """
    Encodes nodes using the 'convolutional' GraphSage approach.

    A single learnable weight matrix is applied to either the aggregated
    neighbourhood features alone (``gcn=True``) or to the concatenation of a
    node's own features with its aggregated neighbourhood features
    (``gcn=False``), followed by a ReLU.

    Args:
        input_size: width of one incoming feature vector.
        out_size: width of the embedding produced by this layer.
        gcn: if True the weight has shape (out_size, input_size); otherwise
            (out_size, 2 * input_size) because self and neighbour features
            are concatenated.
    """

    def __init__(self, input_size, out_size, gcn=False):
        super(SageLayer, self).__init__()
        self.input_size = input_size
        self.out_size = out_size
        self.gcn = gcn
        # without gcn: concat(h_v^{k-1}, h_{N(v)}^k) -> input_size * 2
        in_width = self.input_size if self.gcn else 2 * self.input_size
        # torch.empty replaces the legacy torch.FloatTensor(sizes) constructor;
        # the values are overwritten by init_params() right below.
        self.weight = nn.Parameter(
            torch.empty(out_size, in_width, dtype=torch.float32))
        self.init_params()

    def init_params(self):
        """Xavier-uniform initialise every parameter of the layer."""
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(self, self_feats, aggregate_feats, neighs=None):
        """
        Produce the layer's embeddings.

        self_feats      -- features of the nodes themselves (one row per node)
        aggregate_feats -- aggregated neighbour features (one row per node)
        neighs          -- accepted for interface compatibility; not used here

        Returns a tensor with one row per node and ``out_size`` columns.
        """
        if not self.gcn:
            combined = torch.cat([self_feats, aggregate_feats], dim=1)
        else:
            combined = aggregate_feats
        # weight is (out, in) and combined is (nodes, in): multiply on the
        # transposed input, then transpose back to (nodes, out).
        combined = F.relu(self.weight.mm(combined.t())).t()
        return combined
class GraphSage(nn.Module):
    """
    Multi-layer GraphSAGE encoder.

    For a batch of nodes the model samples a neighbourhood once per layer,
    then runs the layers from the outermost sampled neighbourhood inwards,
    each layer mean-aggregating the previous layer's embeddings.

    Args:
        num_layers: number of SageLayer modules (stored as attributes
            ``sage_layer1`` ... ``sage_layer<num_layers>``).
        input_size: width of a raw feature vector (input of layer 1).
        out_size: width of every layer's output.
        raw_features: feature tensor indexed by node id.
        adj_lists: mapping node id -> collection of neighbour ids
            (presumably sets, as in the referenced repository -- confirm
            against the caller).
        device: stored on the instance; not otherwise used in this class.
        gcn: forwarded to every SageLayer; also controls whether a node
            stays in its own neighbourhood during aggregation.
        agg_func: stored on the instance. NOTE: ``aggregate`` below only
            implements mean aggregation and does not consult this value.
    """

    def __init__(self, num_layers, input_size, out_size, raw_features, adj_lists, device, gcn=False, agg_func='MEAN'):
        super(GraphSage, self).__init__()
        self.input_size = input_size
        self.out_size = out_size
        self.num_layers = num_layers
        self.gcn = gcn
        self.device = device
        self.agg_func = agg_func
        self.raw_features = raw_features
        self.adj_lists = adj_lists
        for index in range(1, num_layers + 1):
            # only the first layer consumes raw features
            layer_size = out_size if index != 1 else input_size
            setattr(self, 'sage_layer' + str(index),
                    SageLayer(layer_size, out_size, gcn=self.gcn))

    def forward(self, nodes_batch):
        """
        Generates embeddings for a batch of nodes.

        nodes_batch -- batch of nodes to learn the embeddings

        Returns one embedding row per node in ``nodes_batch``.
        """
        lower_layer_nodes = list(nodes_batch)
        # The last entry is the batch itself; each sampling round prepends
        # the (larger) node set needed to compute the entry after it.
        nodes_batch_layers = [(lower_layer_nodes,)]
        for _ in range(self.num_layers):
            lower_samp_neighs, lower_layer_nodes_dict, lower_layer_nodes = \
                self._get_unique_neighs_list(lower_layer_nodes)
            nodes_batch_layers.insert(
                0, (lower_layer_nodes, lower_samp_neighs, lower_layer_nodes_dict))
        assert len(nodes_batch_layers) == self.num_layers + 1

        pre_hidden_embs = self.raw_features
        for index in range(1, self.num_layers + 1):
            nb = nodes_batch_layers[index][0]
            pre_neighs = nodes_batch_layers[index - 1]
            aggregate_feats = self.aggregate(nb, pre_hidden_embs, pre_neighs)
            sage_layer = getattr(self, 'sage_layer' + str(index))
            if index > 1:
                # After layer 1 the embeddings are rows of the previous
                # layer's output, so node ids must be mapped to row numbers.
                nb = self._nodes_map(nb, pre_hidden_embs, pre_neighs)
            cur_hidden_embs = sage_layer(self_feats=pre_hidden_embs[nb],
                                         aggregate_feats=aggregate_feats)
            pre_hidden_embs = cur_hidden_embs
        return pre_hidden_embs

    def _nodes_map(self, nodes, hidden_embs, neighs):
        """Translate node ids into row indices of the previous layer's output."""
        layer_nodes, samp_neighs, layer_nodes_dict = neighs
        assert len(samp_neighs) == len(nodes)
        index = [layer_nodes_dict[x] for x in nodes]
        return index

    def _get_unique_neighs_list(self, nodes, num_sample=10):
        """
        Sample up to ``num_sample`` neighbours per node.

        Returns a tuple ``(samp_neighs, unique_nodes, unique_nodes_list)``:
        the per-node sampled neighbour sets (each including the node itself),
        a dict mapping every node appearing in those sets to a row index, and
        the list of those nodes in row order.
        """
        to_neighs = [self.adj_lists[int(node)] for node in nodes]
        if num_sample is not None:
            # random.sample() rejects sets on Python 3.11+, so materialise
            # the population as a list first.
            samp_neighs = [
                set(random.sample(list(to_neigh), num_sample))
                if len(to_neigh) >= num_sample else set(to_neigh)
                for to_neigh in to_neighs
            ]
        else:
            samp_neighs = [set(to_neigh) for to_neigh in to_neighs]
        # every node is part of its own neighbourhood at this stage
        samp_neighs = [samp_neigh | {nodes[i]}
                       for i, samp_neigh in enumerate(samp_neighs)]
        _unique_nodes_list = list(set.union(*samp_neighs))
        unique_nodes = {n: i for i, n in enumerate(_unique_nodes_list)}
        return samp_neighs, unique_nodes, _unique_nodes_list

    def aggregate(self, nodes, pre_hidden_embs, pre_neighs, num_sample=10):
        """
        Mean-aggregate neighbour embeddings for ``nodes``.

        nodes           -- nodes whose neighbourhoods are aggregated
        pre_hidden_embs -- embeddings produced by the previous layer (or the
                           raw features for the first layer)
        pre_neighs      -- (unique_nodes_list, samp_neighs, unique_nodes) as
                           returned by ``_get_unique_neighs_list`` (reordered)
        num_sample      -- unused here; kept for interface compatibility

        Returns a tensor with one row per node. A node with no neighbours
        left to aggregate gets an all-zero row.
        """
        unique_nodes_list, samp_neighs, unique_nodes = pre_neighs
        assert len(nodes) == len(samp_neighs)
        assert all(node in neigh for node, neigh in zip(nodes, samp_neighs))
        if not self.gcn:
            # without gcn the node's own embedding is concatenated later,
            # so drop it from the neighbourhood being averaged
            samp_neighs = [neigh - {node}
                           for node, neigh in zip(nodes, samp_neighs)]
        # NOTE(review): equal length is used as the signal that rows are
        # already ordered like unique_nodes_list; at the first layer this
        # relies on that ordering matching the node ids -- confirm.
        if len(pre_hidden_embs) == len(unique_nodes):
            embed_matrix = pre_hidden_embs
        else:
            embed_matrix = pre_hidden_embs[torch.LongTensor(unique_nodes_list)]
        mask = torch.zeros(len(samp_neighs), len(unique_nodes))
        column_indices = [unique_nodes[n]
                          for samp_neigh in samp_neighs for n in samp_neigh]
        row_indices = [i for i, samp_neigh in enumerate(samp_neighs)
                       for _ in samp_neigh]
        if row_indices:
            mask[row_indices, column_indices] = 1
        # mean aggregation; clamp so a node without neighbours yields zeros
        # instead of 0/0 = NaN
        num_neigh = mask.sum(1, keepdim=True).clamp(min=1)
        mask = mask.div(num_neigh).to(embed_matrix.device)
        aggregate_feats = mask.mm(embed_matrix)
        return aggregate_feats
class UnsupervisedLoss(object):
    """
    Random-walk based unsupervised loss for GraphSAGE embeddings.

    Positive pairs come from short random walks starting at each target
    node; negative pairs are training nodes that are not reachable within
    ``N_WALK_LEN`` hops of the target node.

    Args:
        adj_lists: mapping node id -> set of neighbour ids.
        train_nodes: iterable of node ids allowed as positive/negative
            partners (converted with ``set(...)`` internally).
        device: stored on the instance; not otherwise used in this class.
    """

    def __init__(self, adj_lists, train_nodes, device):
        super(UnsupervisedLoss, self).__init__()
        self.Q = 10           # weight applied to the negative-sample term
        self.N_WALKS = 6      # random walks started per node
        self.WALK_LEN = 1     # steps per random walk
        self.N_WALK_LEN = 5   # hop radius excluded from negative sampling
        self.MARGIN = 3       # not used by the methods in this block
        self.adj_lists = adj_lists
        self.train_nodes = train_nodes
        self.device = device
        self.target_nodes = None
        self.positive_pairs = []
        self.negtive_pairs = []
        self.node_positive_pairs = {}
        self.node_negtive_pairs = {}
        self.unique_nodes_batch = []

    def get_loss_sage(self, embeddings, nodes):
        """
        Compute the loss over the pairs prepared by ``extend_nodes``.

        embeddings -- one row per entry of ``self.unique_nodes_batch``
        nodes      -- must match ``self.unique_nodes_batch`` element-wise

        Returns a scalar tensor: the mean over target nodes of
        ``mean(-log(sigmoid(pos)) - Q * mean(log(sigmoid(-neg))))`` where
        ``pos``/``neg`` are cosine similarities of the paired embeddings.
        Nodes lacking positive or negative pairs are skipped.
        """
        assert len(embeddings) == len(self.unique_nodes_batch)
        assert all(nodes[i] == self.unique_nodes_batch[i] for i in range(len(nodes)))
        node2index = {n: i for i, n in enumerate(self.unique_nodes_batch)}

        nodes_score = []
        assert len(self.node_positive_pairs) == len(self.node_negtive_pairs)
        for node in self.node_positive_pairs:
            pps = self.node_positive_pairs[node]
            nps = self.node_negtive_pairs[node]
            if len(pps) == 0 or len(nps) == 0:
                continue

            # Q * Expectation(negative score)
            indexs = [list(x) for x in zip(*nps)]
            node_indexs = [node2index[x] for x in indexs[0]]
            neighb_indexs = [node2index[x] for x in indexs[1]]
            neg_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
            neg_score = self.Q * torch.mean(torch.log(torch.sigmoid(-neg_score)), 0)

            # multiple positive score
            indexs = [list(x) for x in zip(*pps)]
            node_indexs = [node2index[x] for x in indexs[0]]
            neighb_indexs = [node2index[x] for x in indexs[1]]
            pos_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
            pos_score = torch.log(torch.sigmoid(pos_score))

            nodes_score.append(torch.mean(- pos_score - neg_score).view(1, -1))

        loss = torch.mean(torch.cat(nodes_score, 0))
        return loss

    def extend_nodes(self, nodes, num_neg=6):
        """
        Build positive and negative pairs for ``nodes``.

        Resets all pair state, samples the pairs, and returns the list of
        distinct nodes that appear in any pair (also stored as
        ``self.unique_nodes_batch``).
        """
        self.positive_pairs = []
        self.node_positive_pairs = {}
        self.negtive_pairs = []
        self.node_negtive_pairs = {}

        self.target_nodes = nodes
        self.get_positive_nodes(nodes)
        self.get_negtive_nodes(nodes, num_neg)
        self.unique_nodes_batch = list(
            {i for x in self.positive_pairs for i in x}
            | {i for x in self.negtive_pairs for i in x})
        # Every target must be covered; the batch may legitimately consist of
        # exactly the targets, so a non-strict subset check is used.
        assert set(self.target_nodes) <= set(self.unique_nodes_batch)
        return self.unique_nodes_batch

    def get_positive_nodes(self, nodes):
        """Sample positive pairs via random walks; returns ``self.positive_pairs``."""
        return self._run_random_walks(nodes)

    def get_negtive_nodes(self, nodes, num_neg):
        """
        Sample up to ``num_neg`` negative partners per node.

        Negatives are training nodes not reached by a breadth-first expansion
        of ``N_WALK_LEN`` hops from the node. Results are appended to
        ``self.negtive_pairs`` and stored per node in
        ``self.node_negtive_pairs``. Returns ``self.negtive_pairs``.
        """
        train_set = set(self.train_nodes)  # loop-invariant
        for node in nodes:
            neighbors = {node}
            frontier = {node}
            for _ in range(self.N_WALK_LEN):
                current = set()
                for outer in frontier:
                    current.update(self.adj_lists[int(outer)])
                frontier = current - neighbors
                neighbors |= current
            far_nodes = train_set - neighbors
            if num_neg < len(far_nodes):
                # random.sample() rejects sets on Python 3.11+
                neg_samples = random.sample(list(far_nodes), num_neg)
            else:
                neg_samples = far_nodes
            pairs = [(node, neg_node) for neg_node in neg_samples]
            self.negtive_pairs.extend(pairs)
            self.node_negtive_pairs[node] = pairs
        return self.negtive_pairs

    def _run_random_walks(self, nodes):
        """
        Run ``N_WALKS`` walks of ``WALK_LEN`` steps from every node.

        Each visited node that differs from the start node and belongs to the
        training nodes forms a positive pair with the start node. Nodes
        without neighbours are skipped and get no entry in
        ``self.node_positive_pairs``. Returns ``self.positive_pairs``.
        """
        train_set = set(self.train_nodes)  # O(1) membership inside the loops
        for node in nodes:
            if len(self.adj_lists[int(node)]) == 0:
                continue
            cur_pairs = []
            for _ in range(self.N_WALKS):
                curr_node = node
                for _ in range(self.WALK_LEN):
                    neighs = self.adj_lists[int(curr_node)]
                    next_node = random.choice(list(neighs))
                    # self co-occurrences are useless
                    if next_node != node and next_node in train_set:
                        self.positive_pairs.append((node, next_node))
                        cur_pairs.append((node, next_node))
                    curr_node = next_node

            self.node_positive_pairs[node] = cur_pairs
        return self.positive_pairs