Word2Vec Source Code Walkthrough (PyTorch Version)

This article walks through a PyTorch implementation of the Word2Vec model, covering both the Skip-gram and CBOW variants, including the model structure and the results of running the code, to help readers build a deeper understanding of how Word2Vec works.

- Model structure diagram
- Model implementation
Skip-gram Model

```python
# code by Tae Hwan Jung @graykode, modified by 前行follow
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

def random_batch():
    # Sample `batch_size` (target, context) pairs without replacement
    random_inputs = []
    random_labels = []
    random_index = np.random.choice(range(len(skip_grams)), batch_size, replace=False)

    for i in random_index:
        random_inputs.append(np.eye(voc_size)[skip_grams[i][0]])  # target word as a one-hot vector
        random_labels.append(skip_grams[i][1])                    # context word index

    return random_inputs, random_labels

# Model
class Word2Vec(nn.Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        # W and WT are two independent weight matrices, not transposes of each other
        self.W = nn.Linear(voc_size, embedding_size, bias=False)   # voc_size -> embedding_size
        self.WT = nn.Linear(embedding_size, voc_size, bias=False)  # embedding_size -> voc_size

    def forward(self, X):
        # X : [batch_size, voc_size]
        hidden_layer = self.W(X)              # hidden_layer : [batch_size, embedding_size]
        output_layer = self.WT(hidden_layer)  # output_layer : [batch_size, voc_size]
        return output_layer

if __name__ == '__main__':
    batch_size = 2      # mini-batch size
    embedding_size = 2  # embedding size

    sentences = ["apple banana fruit", "banana orange fruit", "orange banana fruit",
                 "dog cat animal", "cat monkey animal", "monkey dog animal"]

    word_sequence = " ".join(sentences).split()
    word_list = list(set(word_sequence))
    word_dict = {w: i for i, w in enumerate(word_list)}
    voc_size = len(word_list)

    # Build skip-gram (target, context) pairs with a window size of 1
    skip_grams = []
    for i in range(1, len(word_sequence) - 1):
        target = word_dict[word_sequence[i]]
        context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]
        for w in context:
            skip_grams.append([target, w])

    model = Word2Vec()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training
    for epoch in range(5000):
        input_batch, target_batch = random_batch()
        input_batch = torch.Tensor(np.array(input_batch))  # [batch_size, voc_size]
        target_batch = torch.LongTensor(target_batch)      # [batch_size], class indices

        optimizer.zero_grad()
        output = model(input_batch)
        loss = criterion(output, target_batch)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

        loss.backward()
        optimizer.step()

    # Plot each word at its learned 2-D embedding coordinates
    for i, label in enumerate(word_list):
        W, WT = model.parameters()
        x, y = W[0][i].item(), W[1][i].item()
        plt.scatter(x, y)
        plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
                     ha='right', va='bottom')
    plt.show()
```
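Because the input `X` is a one-hot vector, `self.W(X)` simply picks out one column of `W`'s weight matrix, so the first linear layer behaves exactly like an embedding lookup. The sketch below (not part of the original code; it assumes the `model`, `word_dict`, and `voc_size` defined above are in scope, and `embedding_for` is a hypothetical helper name) shows how you might read out a trained word vector:

```python
import numpy as np
import torch

def embedding_for(word):
    # Hypothetical helper: assumes `model`, `word_dict`, and `voc_size`
    # from the Skip-gram script above are in scope.
    one_hot = torch.Tensor(np.eye(voc_size)[word_dict[word]]).unsqueeze(0)  # [1, voc_size]
    with torch.no_grad():
        return model.W(one_hot).squeeze(0)  # [embedding_size]

# The same vector can be read directly as a column of the weight matrix:
# model.W.weight[:, word_dict[word]]
```

After training, words that appear in the same contexts (e.g. "apple" and "banana", which both precede "fruit") should end up with nearby embedding vectors, which is what the scatter plot at the end of the script visualizes.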