# DeepWalk 代码实战
# 原理参考：DeepWalk 原论文（Paper）
class DeepWalk:
    """DeepWalk node embedding.

    Simulates random walks over ``graph`` when constructed. ``train()`` then
    fits a ``Word2Vec`` model on those walks, treating each walk as a sentence.
    The model is presumably gensim's, judging by the keyword names used.
    """

    def __init__(self, graph, walk_length, num_walks, workers=1):
        """Build the walker and pre-compute the walk corpus.

        Args:
            graph: Graph to embed. Other methods call ``graph.nodes()`` and
                ``graph.neighbors(node)`` (networkx-style; confirm with callers).
            walk_length: Walk length, passed through to
                ``RandomWalker.simulate_walks``.
            num_walks: Number of walks, passed through to ``simulate_walks``.
                This is presumably the count per node; confirm against
                RandomWalker.
            workers: Parallelism passed to ``simulate_walks``.
        """
        self.graph = graph
        self.w2v_model = None
        self._embeddings = {}
        # NOTE(review): p = q = 1 presumably turns RandomWalker's node2vec-style
        # biased walk into an unbiased one, i.e. the original DeepWalk
        # walk. Confirm against RandomWalker.
        self.walker = RandomWalker(graph, p=1, q=1)
        self.sentences = self.walker.simulate_walks(
            num_walks=num_walks, walk_length=walk_length, workers=workers, verbose=1
        )

    def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs):
        """Fit a skip-gram Word2Vec model on the pre-computed walks.

        Supports both gensim keyword schemes. gensim >= 4.0 renamed
        ``size`` -> ``vector_size`` and ``iter`` -> ``epochs``, and passing the
        old names there raises ``TypeError``.

        Args:
            embed_size: Embedding dimensionality.
            window_size: Context window size.
            workers: Number of Word2Vec worker threads.
            iter: Number of training epochs. The name is kept for backward
                compatibility even though it shadows the builtin.
            **kwargs: Extra Word2Vec arguments. ``min_count`` defaults to 0,
                so no node is dropped for being rare in the walks.

        Returns:
            The trained model, which is also stored on ``self.w2v_model``.
        """
        import inspect

        kwargs["sentences"] = self.sentences
        kwargs["min_count"] = kwargs.get("min_count", 0)
        kwargs["sg"] = 1  # skip-gram
        kwargs["hs"] = 1  # hierarchical softmax
        kwargs["workers"] = workers
        kwargs["window"] = window_size

        # Use whichever keyword names the installed Word2Vec actually accepts.
        try:
            w2v_params = inspect.signature(Word2Vec).parameters
        except (TypeError, ValueError):
            w2v_params = {}
        if "vector_size" in w2v_params:  # gensim >= 4.0 naming
            kwargs["vector_size"] = embed_size
            kwargs["epochs"] = iter
        else:  # gensim < 4.0 naming
            kwargs["size"] = embed_size
            kwargs["iter"] = iter

        print("Learning embedding vectors...")
        model = Word2Vec(**kwargs)
        print("Learning embedding vectors done!")
        self.w2v_model = model
        return model

    def get_embeddings(self):
        """Return ``{node: vector}`` for every node in ``self.graph``.

        If ``train()`` has not been called yet, prints a notice and returns an
        empty dict. A node that never appeared in the training walks is looked
        up in ``wv`` anyway. This presumably raises ``KeyError`` (gensim
        behaviour; confirm).
        """
        if self.w2v_model is None:
            print("model not train")
            return {}
        self._embeddings = {}
        for word in self.graph.nodes():
            self._embeddings[word] = self.w2v_model.wv[word]
        return self._embeddings

    def deepwalk_walk(self, walk_length, start_node):
        """Return an unbiased random walk of at most ``walk_length`` nodes.

        The walk starts at ``start_node``, which is always the first element.
        Each step picks a neighbour of the current node uniformly at random.
        The walk stops early when it reaches a node with no neighbours.
        """
        walk = [start_node]
        while len(walk) < walk_length:
            # Previously read self.G, which this class never sets (AttributeError).
            neighbors = list(self.graph.neighbors(walk[-1]))
            if not neighbors:
                break  # dead end: nowhere to go from here
            walk.append(random.choice(neighbors))
        return walk