一、安装fastext
fastText在Python2和Python3中都可以使用,已有了现成的包,但只支持Linux和mac系统,windows暂时还不支持fastText。PS.本人尝试在windows上安装过fasttext,没安装成功,又花了很长时间找解决方案,尝试无果后,果断在linux系统中安装,并成功运行。
使用命令:pip install fasttext即可安装,安装过程中若出现以下几种错误,可用对应的解决方案重新安装
1.出现ModuleNotFoundError: No module named 'Cython’错误:
解决方案:用pip安装Cython库就可以,命令为:pip install Cython
2.出现gcc: error trying to exec ‘cc1plus’: execvp: No such file or directory错误:
解决方案:没有安装g++,安装一下g++就可以了,然后用g++ -v,gcc -v分别查看一下gcc和g++的版本是否一致,安装好g++后,再重新安装fasttext就可以了。
二、数据准备
需要注意的是,训练的数据文件中类别的前/后缀默认是"_label_",类别和句子之间用tab键分开,句子可以是按照词分割开的,也可以是按照字分割开的。数据样例如下图所示:
三、训练模型
产生的模型文件会被保存成以.bin结尾的文件。
fasttext参数如下:
The following arguments are optional:
-lr learning rate [0.05]
-lrUpdateRate change the rate of updates for the learning rate [100]
-dim size of word vectors [100]
-ws size of the context window [5]
-epoch number of epochs [5]
-minCount minimal number of word occurences [1]
-neg number of negatives sampled [5]
-wordNgrams max length of word ngram [1]
-loss loss function {ns, hs, softmax} [ns]
-bucket number of buckets [2000000]
-minn min length of char ngram [3]
-maxn max length of char ngram [6]
-thread number of threads [12]
-t sampling threshold [0.0001]
-label labels prefix [__label__]
训练代码:
import fasttext
import os
import sys
from sklearn import metrics
# 这里是为了中文乱码做的一些转码
if sys.version_info[0] > 2:
is_py3 = True
else:
is_py3 = False
def native_content(content):
if not is_py3:
return content.decode('utf-8')
else:
return content
# 这个方法是读取文件,主要是训练集和测试集
def open_file(filename, mode='r'):
if is_py3:
return open(filename, mode, encoding='utf-8', errors='ignore')
else:
return open(filename, mode)
# 这个方法是切割数据,以“\t”进行分割,返回内容和对应标签的2个list
def read_file(filename):
"""读取文件数据"""
contents, labels = [], []
with open_file(filename) as f:
for line in f:
try:
content, label = line.strip().split('\t')
if content:
contents.append(native_content(content))
labels.append(native_content(label))
except:
pass
return contents, labels
if __name__ == '__main__':
# 判断输入参数是否含有语料路径的参数
if len(sys.argv) != 2:
print("The number of parameters is not correct!")
exit()
filename = sys.argv[1]
print("input param:%s" % filename)
print(os.path.exists('./%s/model.bin' % filename))
classifier = None
# 判断是否已经存在模型,如果存在则加载,不存在则进行训练
if os.path.exists('./model/fasttext_model.bin'):
classifier = fasttext.load_model('./model/fasttext_model.bin')
else:
# 训练模型
if not os.path.exists('./%s/' % filename):
os.mkdir('./%s/' % filename)
print("正在训练模型")
fasttext.supervised('./%s/train.txt' % filename, './model/fasttext_model') # 模型自动以.bin后缀名保存
print("训练完成")
运行程序:python fasttext_textclassify.py data/
data为训练数据的路径,数据在data/文件夹下。
四、预测
预测之前需要先加载相关模型,加载模型若出现以下错误:
解决方案:先查看是否将模型路径写错,若写错修改正确的路径;若发现路径没错误,那可能是因为训练模型太大,加载不出来,可以使用pyfasttext库试一下。
预测代码:
test_sentence = input("Please input the test sentence:")
classifier = fasttext.load_model('./model/fasttext_model.bin', label_prefix='__label__')
predict_list = []
predict_list.append(' '.join(jieba.lcut(test_sentence)))
pred = classifier.predict(predict_list)
print("分类为:" + str(pred))
# 输出预测结果以及类别概率
print(classifier.predict_proba(test_sentence))
# 得到前k个类别
lables_top3 = classifier.predict(texts,k=3)
print(lables_top3)
# 得到前k个类别+概率
lables_prob3 = classifier.predict_proba(texts,k=3)
print(lables_prob3)
结果如下图: