基于pytorch实现Word2Vec(skip-gram+Negative Sampling)

目录

word2vec简介

语料处理

数据预处理

训练模型

近似训练法

参数设定

预测及可视化


word2vec简介

2013 年,Google 团队发表了 word2vec 工具。word2vec 工具主要包含两个模型:跳字模型(skip-gram)和连续词模型(continuous bag of words,简称 CBOW),以及两种高效训练的方法:负采样(negative sampling)和层序 softmax(hierarchical softmax)。
类似于f(x)->y,Word2vec 的最终目的,不是要把 f 训练得多么完美,而是只关心模型训练完后的副产物——模型参数(这里特指神经网络的权重),并将这些参数,作为输入 x 的某种向量化的表示,这个向量便叫做——词向量。
word2vec 词向量可以较好地表达不同词之间的相似度和类比关系。


语料处理

步骤:

  1. 使用 re 的 findall 方法以及正则表达式去除标点符号;
  2. 使用 jieba 进行分词;
  3. 使用停用词典剔除无意义的词。

处理前:

处理后:

代码如下:

import re
import jieba

stopwords = {}
fstop = open('stop_words.txt', 'r', encoding='utf-8', errors='ingnore')
for eachWord in fstop:
    stopwords[eachWord.strip()] = eachWord.strip()  # 创建停用词典
fstop.close()

f1 = open('红楼梦.txt', 'r', encoding='utf-8', errors='ignore')
f2 = open('红楼梦_p.txt', 'w', encoding='utf-8')

line = f1.readline()
while line:
    line = line.strip()  # 去前后的空格
    if line.isspace():  # 跳过空行
        line = f1.readline()

    line = re.findall('[\u4e00-\u9fa5]+', line)  # 去除标点符号
    line = "".join(line)

    seg_list = jieba.cut(line, cut_all=False)  # 结巴分词

    outStr = ""
    for word in seg_list:
        if word not in stopwords:  # 去除停用词
            outStr += word
            outStr += " "

    if outStr:  # 不为空添加换行符
        outStr = outStr.strip() + '\n'

    f2.writelines(outStr)
    line = f1.readline()

f1.close()
f2.close()

数据预处理

步骤:

  1. 剔除低频词;
  2. 生成 id 到 word、word 到 id 的映射;
  3. 使用 subsampling 处理语料;
  4. 定义获取正、负样本方法;
  5. 估计数据中正采样对数。

测试结果:

 这里 min_count=1 也就是不剔除低频词,窗口大小设定为2,负样本数量 k 设定为3。

代码如下:

import math
import numpy
from collections import deque
from numpy import random

numpy.random.seed(6)


class InputData:

    def __init__(self, file_name, min_count):
        self.input_file_name = file_name
        self.get_words(min_count)
        self.word_pair_catch = deque()  # deque为队列,用来读取数据
        self.init_sample_table()  # 采样表
        print('Word Count: %d' % len(self.word2id))
        print("Sentence_Count:", self.sentence_count)
        print("Sentence_Length:", self.sentence_length)

    def get_words(self, min_count):  # 剔除低频词,生成id到word、word到id的映射
        self.input_file = open(self.input_file_name, encoding="utf-8")
        self.sentence_length = 0
        self.sentence_count = 0
        word_frequency = dict()
        for line in self.input_file:
            self.sentence_count += 1
            line = line.strip().split(' ')  # strip()去除首尾空格,split(' ')按空格划分词
            self.sentence_length += len(line)
            for w in line:
                try:
                    word_frequency[w] += 1
                except:
                    word_frequency[w] = 1
        self.word2id = dict()
        self.id2word = dict()
        wid = 0
        self.word_frequency = dict()
        for w, c in word_fre
  • 6
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值