fasttext文本分类实战

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/IT_xiao_bai/article/details/86629543

一、安装fastext

fastText在Python2和Python3中都可以使用,已有了现成的包,但只支持Linux和mac系统,windows暂时还不支持fastText。PS.本人尝试在windows上安装过fasttext,没安装成功,又花了很长时间找解决方案,尝试无果后,果断在linux系统中安装,并成功运行。

使用命令:pip install fasttext即可安装,安装过程中若出现以下几种错误,可用对应的解决方案重新安装

1.出现ModuleNotFoundError: No module named 'Cython’错误:
在这里插入图片描述
解决方案:用pip安装Cython库就可以,命令为:pip install Cython

2.出现gcc: error trying to exec ‘cc1plus’: execvp: No such file or directory错误:
在这里插入图片描述
解决方案:没有安装g++,安装一下g++就可以了,然后用g++ -v,gcc -v分别查看一下gcc和g++的版本是否一致,安装好g++后,再重新安装fasttext就可以了。

二、数据准备

需要注意的是,训练的数据文件中类别的前/后缀默认是"_label_",类别和句子之间用tab键分开,句子可以是按照词分割开的,也可以是按照字分割开的。数据样例如下图所示:
在这里插入图片描述

三、训练模型

产生的模型文件会被保存成以.bin结尾的文件。

fasttext参数如下:

The following arguments are optional:
-lr            learning rate [0.05]
-lrUpdateRate  change the rate of updates for the learning rate [100]
-dim           size of word vectors [100]
-ws            size of the context window [5]
-epoch         number of epochs [5]
-minCount      minimal number of word occurences [1]
-neg           number of negatives sampled [5]
-wordNgrams    max length of word ngram [1]
-loss          loss function {ns, hs, softmax} [ns]
-bucket        number of buckets [2000000]
-minn          min length of char ngram [3]
-maxn          max length of char ngram [6]
-thread        number of threads [12]
-t             sampling threshold [0.0001]
-label         labels prefix [__label__]

训练代码:

import fasttext
import os
import sys
from sklearn import metrics

# 这里是为了中文乱码做的一些转码
if sys.version_info[0] > 2:
    is_py3 = True
else:
    is_py3 = False


def native_content(content):
    if not is_py3:
        return content.decode('utf-8')
    else:
        return content


# 这个方法是读取文件,主要是训练集和测试集
def open_file(filename, mode='r'):
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    else:
        return open(filename, mode)


# 这个方法是切割数据,以“\t”进行分割,返回内容和对应标签的2个list
def read_file(filename):
    """读取文件数据"""
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                content, label = line.strip().split('\t')
                if content:
                    contents.append(native_content(content))
                    labels.append(native_content(label))
            except:
                pass
    return contents, labels


if __name__ == '__main__':
    # 判断输入参数是否含有语料路径的参数
    if len(sys.argv) != 2:
        print("The number of parameters is not correct!")
        exit()

    filename = sys.argv[1]
    print("input param:%s" % filename)
    print(os.path.exists('./%s/model.bin' % filename))

    classifier = None
    # 判断是否已经存在模型,如果存在则加载,不存在则进行训练
    if os.path.exists('./model/fasttext_model.bin'):
        classifier = fasttext.load_model('./model/fasttext_model.bin')
    else:
        # 训练模型
        if not os.path.exists('./%s/' % filename):
            os.mkdir('./%s/' % filename)
        print("正在训练模型")
        fasttext.supervised('./%s/train.txt' % filename, './model/fasttext_model')  # 模型自动以.bin后缀名保存
        print("训练完成")

运行程序:python fasttext_textclassify.py data/
data为训练数据的路径,数据在data/文件夹下。

四、预测

预测之前需要先加载相关模型,加载模型若出现以下错误:
在这里插入图片描述
解决方案:先查看是否将模型路径写错,若写错修改正确的路径;若发现路径没错误,那可能是因为训练模型太大,加载不出来,可以使用pyfasttext库试一下。

预测代码:

test_sentence = input("Please input the test sentence:")
classifier = fasttext.load_model('./model/fasttext_model.bin', label_prefix='__label__')
predict_list = []
predict_list.append(' '.join(jieba.lcut(test_sentence)))
pred = classifier.predict(predict_list)
print("分类为:" + str(pred))
# 输出预测结果以及类别概率
print(classifier.predict_proba(test_sentence))
# 得到前k个类别
lables_top3 = classifier.predict(texts,k=3)
print(lables_top3)
# 得到前k个类别+概率
lables_prob3 = classifier.predict_proba(texts,k=3)
print(lables_prob3)

结果如下图:
在这里插入图片描述

展开阅读全文

没有更多推荐了,返回首页