目录
一、Word2Vec模型
1.skip-gram
已知中心词预测上下文词。(推断上下)
2.cbow
已知上下文词预测中心词。(补空)
二、AG新闻数据集加载
1.AG新闻数据集
- 学术社区ComeToMyHead从2000多不同的新闻来源搜集的超过一百万的新闻文章。
- 用于 分类、聚类、信息获取(排序、搜索)
- 四类:World、Sports、Business、Sci、Tec
- 第一列:新闻分类class /// 第二列:新闻标题title /// 第三列:新闻正文context
- 127600个样本:120000训练集 + 7600测试集
2.AG新闻数据集下载
下载地址:AG News 新闻文章数据集 - 数据集下载 - 超神经
3.查看数据
例:最后三条数据
三、停用词表
1.NLTK工具包下载
pip install -U nltk
2.NLTK控制台
在终端进入NLTK,下载控制台
3.在控制台下载停用词
"Corpora"-->"stopwords"
注意下载路径---最好不要改昂
四、数据预处理
- 删除标点符号
- 统一改为小写
- 替换多重空格
- 以空格分词
- 去除停用词
import csv
import re
import numpy as np
import nltk
from nltk.stem.porter import *
stoplist = nltk.corpus.stopwords.words('english')
"""--------1.导入数据--------"""
agnews_train = csv.reader(open("./dataset/train.csv",'r'))
print(agnews_train)
agnews_label = []
agnews_title = []
agnews_text = []
"""--------2.数据预处理-------"""
def text_clear(text):
text = text.lower()
text = re.sub(r"[^a-z0-9]"," ",text)
text = re.sub(r" +"," ",text)
text = text.strip()
text = text.split(" ")
text = [word for word in text if word not in stoplist]
text = [PorterStemmer().stem(word) for word in text]
text.append("eos")
text = ["bos"] + text
return text
for line in agnews_train:
# print(line,"\n")
agnews_label.append(line[0])
agnews_title.append(text_clear(line[1]))
agnews_text.append(text_clear(line[2]))
print(agnews_text,"\n")
结果:预处理之后 最后三条数据
五、训练Word2Vec模型
1.训练模型
"""--------3.模型训练-------"""
from gensim.models import word2vec
model = word2vec.Word2Vec(agnews_text,vector_size = 64,min_count=0,window = 5, epochs = 128)
print(model)
模型超参数:
- min_count(int,optional):忽略词频小于该值的词
- vector_size(int,optional):word向量的维度
- window(int,optional):一个句子中当前词和被预测单词的最大距离
- workers(int,optional):训练模型时使用的线程数
2.保存模型
"""--------4.保存模型-------"""
model_name = "corpusWord2Vec.bin"
model.save(model_name)
3.复用模型
"""--------5.复用模型-------"""
from gensim.models import word2vec
model = word2vec.Word2Vec.load('./corpusWord2Vec.bin')
model.train(agnews_title,epochs= model.epochs,total_examples=model.corpus_count)
(1) 获得所有词汇组:key_to_index
print(model.wv.key_to_index)
输出结果部分展示:
(2)获取某个词汇的向量:model.wv["word"]
print("year:",model.wv["year"])
(3)计算两个词之间的余弦相似:similarity
print("canada_malaysia:",model.wv.similarity("canada","malaysia"))
(4)找近义词:most_similar
print("canada_most_similar:",model.wv.most_similar("canada",topn=5))
(5)找不合群的词:doesnt_match
print("doesnt_match:",model.wv.doesnt_match("buy malaysia australia aviva".split()))
(6)给定上下文词,获得中心词概率分布
print(model.predict_output_word(['sunday', 'price', 'stock'], topn=10))