python-pytorch实现CBOW 0.5.000

数据加载、切词

按照链接https://blog.csdn.net/m0_60688978/article/details/137538274操作后,可以获得的数据如下

  1. wordList 文本中所有的分词,放入这个数组中
  2. raw_text 这个可以忽略,相当于wordlist的备份,防止数据污染了
  3. vocab 将wordList转变为set,即set(wordList)
  4. vocab_size 所有分词的个数
  5. word_to_idx 字典格式,汉字对应索引
  6. idx_to_word 字典格式,索引对应汉字

准备训练数据

data3 = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data3 .append((context, target))
 
print(data3 [:5])
"""
[(['从零开始', 'Zookeeper', '高', '可靠'], '开源'), (['Zookeeper', '开源', '可靠', '分布式'], '高'), (['开源', '高', '分布式', '一致性'], '可靠'), (['高', '可靠', '一致性', '协调'], '分布式'), (['可靠', '分布式', '协调', '服务'], '一致性')]
"""

准备模型和参数

# 超参数
learning_rate = 0.003
device = torch.device('cpu')
embedding_dim = 100
epoch = 10
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.proj = nn.Linear(embedding_dim, 128)
        self.output = nn.Linear(128, vocab_size)
 
    def forward(self, inputs):
        embeds = sum(self.embeddings(inputs)).view(1, -1)
        out = F.relu(self.proj(embeds))
        out = self.output(out)
        nll_prob = F.log_softmax(out, dim=-1)
        return nll_prob
 
model = CBOW(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练

losses = []
loss_function = nn.NLLLoss()
 
for epoch in trange(3000):
    total_loss = 0
    for context, target in data1:
        context_vector = make_context_vector(context, word_to_idx)
        target = torch.tensor([word_to_idx[target]])
        # 梯度清零
        model.zero_grad()
        # 开始前向传播
        train_predict = model(context_vector) 
        loss = loss_function(train_predict, target)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        total_loss += loss.item()
    if epoch % 100 ==0:
            print("loss is ",total_loss,"echo is ",epoch)
        
    losses.append(total_loss)
print("losses-=", losses)
"""
 97%|███████████████████████████████████████████████████████████████████████████▍  | 2902/3000 [07:07<00:13,  7.17it/s]
loss is  0.18700819212244824 echo is  2900
100%|██████████████████████████████████████████████████████████████████████████████| 3000/3000 [07:21<00:00,  6.79it/s]
"""

保存模型

torch.save(model.state_dict(),"model.pth")

加载模型

model = CBOW(vocab_size, embedding_dim).to(device)
model.load_state_dict(torch.load("model.pth"))
print(model)
"""
CBOW(
  (embeddings): Embedding(179, 100)
  (proj): Linear(in_features=100, out_features=128, bias=True)
  (output): Linear(in_features=128, out_features=179, bias=True)
)
"""

简单预测

def cut_sentense(str):
    stop_words = load_stop_words()
    with open('data/zh.txt', encoding='utf8') as f:
        allData = f.readlines()
    result = []
    c_words = jieba.lcut(str)

    for word in c_words:
        if word not in stop_words and word != "\n":
            result.append(word)
    return result
    
context_vector = make_context_vector(cut_sentense("在Master节点使用客户端"), word_to_idx).to(device)
print(context_vector,type(context_vector))
predict = model(context_vector).data.cpu().numpy()
max_idx = np.argmax(predict)
# 输出预测的值
print('Prediction: {}'.format(idx_to_word[max_idx]))

"""
输出中心词语,看上去不怎么样
tensor([120,  37,  49]) <class 'torch.Tensor'>
Prediction: 除主

获取词向量

trained_vector_dic={}
for word, idx in word_to_idx.items(): # 输出每个词的嵌入向量
    trained_vector_dic[word]=model.embedding.weight[idx]
"""
trained_vector_dic内容类似于
{'参数值': tensor([-3.6921e+00, -1.3388e+00,  2.4545e-03, -1.1352e+00, -1.8306e-04,
         -6.3501e-01, -1.4372e-01, -8.2283e-01, -1.6009e+00, -7.4731e-01,
         -1.3509e-01, -2.5100e-01, -1.0037e+00,  9.0061e-01,  1.7794e-01,
         -8.6344e-03, -1.2831e+00, -2.1400e+00,  2.7457e-01,  1.8157e-01,
          2.1480e-01, -2.2192e-02, -3.8433e-01,  1.3575e+00,  1.8483e+00,
         -6.6326e-01, -2.0239e+00, -1.9854e+00,  4.0531e-01, -1.5659e-01,
         -2.7774e+00, -8.2578e-02,  1.5725e+00, -9.9693e-01,  6.0748e-01,
         -6.4992e-01,  8.5653e-01, -1.1889e+00,  1.1657e-04, -3.3866e-01,
          8.2302e-02,  1.0612e-02, -8.8592e-01, -1.9495e-01, -1.2271e-01,
         -4.1997e+00,  1.3430e+00, -6.6779e-01, -1.7927e-01,  3.0450e-01,
          8.4677e-02, -9.5100e-01,  2.5847e-01,  1.1187e+00,  3.1471e+00,
          2.4095e+00, -1.0612e-01,  2.1663e+00, -8.5172e-01, -2.1438e-01,
          2.3635e-01,  4.7740e-01, -2.8115e+00, -1.5964e-01,  4.9957e-02,
          1.6154e-01, -7.0892e-01, -5.6724e-01, -2.2594e-01, -1.2353e+00,
          8.9448e-01, -1.7034e-01, -6.5750e-01,  9.8126e-01, -1.7088e+00,
         -1.9967e-01,  2.6574e-01, -1.3275e-01,  6.1529e-01, -3.6684e-01,
          1.7341e-02,  1.5207e-03, -4.8425e-01, -2.2761e-01, -2.2298e+00,
         -5.5302e-01,  4.4864e-01, -2.5363e-01,  3.4734e-01, -4.4062e-02,
         -1.3769e+00,  1.6567e-01, -7.3674e-01, -8.4163e-01,  2.9937e-01,
          2.3714e+00,  1.2883e+00,  1.2383e-01,  7.5008e-01, -1.3516e-01],
        grad_fn=<SelectBackward0>),
 '05': tensor([ 1.1536e+00, -2.2545e-01, -9.9584e-01,  2.0407e-02,  1.9062e+00,
         -5.5870e-01, -6.1779e-04,  2.7210e-01, -1.9126e+00, -8.1227e-02,
         -6.0733e-02, -3.3426e-03,  9.4838e-01,  3.1968e-01,  1.1331e+00,
          1.9320e-01,  9.8004e-01,  1.3209e-01,  3.9876e-01,  1.9894e-01,
          9.6364e-01, -2.9291e-01, -1.4829e+00,  1.9647e+00, -1.2805e-01,
          1.7458e+00,  9.1834e-02,  7.3453e-01, -1.4541e-01, -1.5197e+00,
          2.5946e-01,  1.1071e+00,  2.3167e-02, -9.9457e-01, -6.4125e-02,
         -2.1326e-01, -2.1815e+00, -8.3949e-02, -3.8223e-01,  2.0616e+00,
         -7.3382e-02,  2.6695e-01,  9.4765e-02, -3.2757e-01, -4.8486e-01,
         -3.0599e-01,  8.8235e-01,  3.1940e-01, -1.3256e-01, -6.0862e-01,
          4.4978e-01, -3.0902e+00,  1.6898e+00,  5.7821e-01, -5.2478e-02,
          4.9577e-01,  4.5494e-01,  5.6485e-04, -2.5271e+00,  3.1652e+00,
         -4.2832e-02, -9.9416e-02,  3.1775e-01, -1.9758e+00, -1.2955e-02,
         -1.6038e+00,  5.3717e-02,  2.9455e-03, -3.6091e-01, -5.7126e-01,
          1.6538e+00, -2.0648e+00, -3.1718e-01, -1.0939e+00,  2.4513e+00,
         -3.5226e-03,  8.0853e-01,  4.0330e-01,  5.2394e-01,  2.7201e+00,
         -2.4086e-01, -3.3241e-01,  2.9677e+00, -2.2749e-01,  3.1172e+00,
          7.8760e-02, -1.0339e+00,  1.4011e+00,  5.2701e-01,  8.9391e-01,
          2.2373e-01,  1.3236e+00, -6.5663e-02,  8.7556e-01,  2.3522e+00,
         -2.2826e-01, -1.4658e-01, -1.8229e+00, -6.5210e-01,  4.1831e-04],
        grad_fn=<SelectBackward0>),
 'HOME': tensor([-1.2881e+00,  9.8371e-01, -1.7626e+00,  6.8964e-02, -1.2208e+00,
         -7.2041e-01,  1.6493e+00,  2.4161e-01,  3.0407e-01,  1.0450e+00,
         -3.7338e-02,  1.2912e+00, -7.8684e-01, -8.1084e-02,  3.1615e+00,
          1.1677e+00, -2.7518e-01,  1.2211e+00,  5.5950e-01, -2.1043e+00,
          5.2210e-01, -1.7408e-01,  5.1499e-02,  7.7797e-01, -1.4519e-03,
         -3.4803e-02, -4.3894e-01, -3.7840e+00,  1.8685e+00,  5.1014e-01,
          2.8481e-04,  7.3540e-01,  4.0983e-02,  1.9889e-01,  2.2323e-01,
         -1.2719e+00,  9.0170e-01, -1.7608e+00,  1.2378e-04,  3.6426e-01,
         -2.3393e-01,  3.9977e-01,  4.6494e-01, -2.2011e+00, -2.1913e-02,
         -2.4567e-04, -2.4916e-01, -9.5079e-01, -2.0207e-01, -7.1489e-02,
         -3.2497e-02, -2.0102e-01,  5.9411e-02, -7.5153e-01, -5.1971e-01,
          2.7858e-01, -1.7449e-01, -2.4816e-02,  6.8960e-01,  1.3359e+00,
          1.4179e+00,  2.1634e-02,  4.1195e-01, -2.4597e+00, -2.2374e+00,
          4.7058e-01, -3.2053e-01,  1.0844e+00, -8.6147e-01,  1.6927e+00,
         -1.0051e-01, -2.3251e+00, -1.3552e+00, -1.3862e+00,  4.0486e-01,
          4.2523e-02, -8.1515e-01,  2.9837e-01, -1.6220e-02,  1.0755e-01,
          3.7893e-01, -1.4399e+00, -2.8273e-01, -1.4445e-01,  3.2650e-01,
          2.5101e+00,  2.7584e-01,  2.6028e-01,  4.5515e-03, -1.3406e+00,
         -6.2879e-02, -3.8538e-01, -1.9729e+00, -1.1987e+00, -1.7349e-01,
         -2.0273e+00,  9.5012e-01,  3.1583e-02,  1.2475e+00,  1.7564e-01],
        grad_fn=<SelectBackward0>)}
"""

降维显示图

这里是参考另外一篇文章见最后的章节

"""
    待转换类型的PyTorch Tensor变量带有梯度,直接将其转换为numpy数据将破坏计算图,
    因此numpy拒绝进行数据转换,实际上这是对开发者的一种提醒。
    如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加detach()调用。
"""
 
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(W)
 
# 降维后在生成一个词嵌入字典,即即{单词1:(维度一,维度二),单词2:(维度一,维度二)...}的格式
word2ReduceDimensionVec = {}
for word in word_to_idx.keys():
    word2ReduceDimensionVec[word] = principalComponents[word_to_idx[word], :]
 
# 将生成的字典写入到文件中,字符集要设定utf8,不然中文乱码
with open("CBOW_ZH_wordvec.txt", 'w', encoding='utf-8') as f:
    for key in word_to_idx.keys():
        f.write('\n')
        f.writelines('"' + str(key) + '":' + str(word_2_vec[key]))
    f.write('\n')
 
# 将词向量可视化
plt.figure(figsize=(20, 20))
# 只画出1000个,太多显示效果很差
count = 0
for word, wordvec in word2ReduceDimensionVec.items():
    if count < 1000:
        plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
        plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号,否则负号会显示成方块
        plt.scatter(wordvec[0], wordvec[1])
        plt.annotate(word, (wordvec[0], wordvec[1]))
        count += 1
plt.show()

在这里插入图片描述

使用词向量计算相似度

参照链接https://blog.csdn.net/m0_60688978/article/details/137535717,第五点

参考

https://blog.csdn.net/Metal1/article/details/132886936
https://blog.csdn.net/L_goodboy/article/details/136347947

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值