1 安装fastText
facebook参考地址
https://github.com/facebookresearch/fastText
fastText安装包
https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext
使用tar文件安装比较麻烦,建议使用whl安装
pip install fasttext‑0.9.2‑cp38‑cp38‑win_amd64.whl
开发文档
# python开发文档
https://fasttext.cc/docs/en/python-module.html
# js开发文档
https://fasttext.cc/docs/en/webassembly-module.html
2 源文件
import fasttext
# 消除警告
# Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar
fasttext.FastText.eprint = lambda x: None
# 分类模型名称
classifier_model_name = "model_classify.bin"
# 训练模型
def train_model():
# 参数标准的取值范围: lr =[0.1, 1.0], epoch=[5-50], wordNgrams=[1-5]
# loss参数
# 实现多标签分类时,loss=ova,ova表示one-vs-all
# 当数据量比较大时,loss=hs,hs表示hierarchical softmax
model = fasttext.train_supervised("train.txt", lr=0.1, epoch=25, wordNgrams=4, loss='softmax', label_prefix='__label__')
# 保存模型
model.save_model(classifier_model_name)
# 测试模型
def test_model():
# 加载模型
classifier = fasttext.load_model(classifier_model_name)
# 测试数据
res_test = classifier.test("test.txt")
print("数据量:", res_test[0])
print("准确率:", res_test[1])
print("召回率:", res_test[2])
predict_file = open('predict.txt', 'w', encoding='utf-8')
with open('test.txt', encoding='utf-8') as fp:
# 每行数据的格式:标签+文本,标签由’__label__‘+类别组成
for line in fp.readlines():
line = line.strip()
# 预测结果, 原始数据
predict_file.write(classifier.predict(line)[0][0] + ',\t' + line + '\n')
predict_file.close()
# 预测模型
def predict_text():
classifier = fasttext.load_model(classifier_model_name)
# text是预测的文本列表, k表示输出标签的数量,-1表示全部输出
res_predict = classifier.predict(text=["新冠肺炎", "智能发展"], k=-1)
print("概率列表:", res_predict)
# text是预测的文本,默认返回相似度最大的标签和概率
res_predict = classifier.predict(text="疫情")
print("概率:", res_predict)
if __name__ == '__main__':
train_model()
test_model()
predict_text()
3 数据格式
训练样本
__label__0 智能科学发展,人工智能科学
__label__1 新冠肺炎疫情
__label__0 深度学习技术,机器学习技术
__label__1 疫情得到有效控制
测试样本
__label__0 中国在人工智能领域得到了一定的发展
__label__1 中国有效控制了疫情