fasttext工具安装
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
# 使用pip安装python中的fasttext工具包
$ sudo pip3 install .
RuntimeError: Unsupported compiler -- at least C++0x support is needed!
解决
yum install gcc-c++
error: command 'gcc' failed with exit status 1
解决
yum install python-devel
采集获取数据
获取烹饪相关的数据集, 它是由facebook AI实验室提供的演示数据集,该数据集cooking.stackexchange.txt中的每一行都包含一个标签列表,后跟相应的文档
下载获取:
wget https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz && tar xvzf cooking.stackexchange.tar.gz
数据格式:所有标签均以__label__前缀开头,这是fastText识别标签或单词的方式. 标签之后的一段话就是文本信息
__label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
训练集,验证集的划分
训练数据
head -n 12404 cooking.stackexchange.txt > cooking.train
验证数据
tail -n 3000 cooking.stackexchange.txt > cooking.valid
模型训练
>>> model=fasttext.train_supervised(input='cooking.train',epoch=50,lr=1.5,wordNgrams=2)
Read 0M words
Number of words: 14543
Number of labels: 735
Progress: 100.0% words/sec/thread: 80697 lr: 0.000000 loss: 4.606543 ETA: 0h 0m
使用模型预测评估
>>> model.predict('rr')
# 元组中的第一项代表标签, 第二项代表对应的概率
(('__label__toasting',), array([1.00012004]))
模型优化
- 数据处理
- 增加训练轮数
- 学习率大小调整
- 增加n-gram特征 wordNgram n-gram特征帮助模型捕捉前后词汇之间的关联, 更好的提取分类规则用于模型分类
- 损失计算方式 loss 换一种低复杂度的方式来计算损失
>>> model=fasttext.train_supervised(input='cooking.train',epoch=50,lr=1.5,wordNgrams=2,loss='hs')
Read 0M words
Number of words: 14543
Number of labels: 735
Progress: 100.0% words/sec/thread: 1551340 lr: 0.000000 loss: 21.357746 ETA: 0h 0m
-
多标签多分类问题的损失计算方式
针对多标签多分类问题, 使用’softmax’或者’hs’有时并不是最佳选择, 我们最终得到的应该是多个标签, 而softmax却只能最大化一个标签. 所以选择的损失计算方式为’ova’表示one vs all.
>>> model = fasttext.train_supervised(input="cooking.train", lr=0.2, epoch=50, wordNgrams=2, loss='ova')
Read 0M words
Number of words: 14543
Number of labels: 735
Progress: 100.0% words/sec/thread: 78905 lr: 0.000000 loss: 8.353354 ETA: 0h 0m
- model.predict
k:代表指定模型输出多少个标签, 默认为1, 这里设置为-1, 意味着尽可能多的输出
threshold:代表显示的标签概率阈值, 设置为0.5, 意味着显示概率大于0.5的标签
>>> model.predict("Which baking dish is best to bake a banana bread ?", k=-1, threshold=0.5)
(('__label__baking', '__label__bread'), array([1.00001001, 0.93629503]))
模型保存与加载
模型保存
model.save_model('./model.bin')
模型加载
fasttext.load_model('./model.bin')