fasttext工具进行文本分类

fasttext工具安装
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
# 使用pip安装python中的fasttext工具包
$ sudo pip3 install .

RuntimeError: Unsupported compiler -- at least C++0x support is needed!

解决
yum install gcc-c++

error: command 'gcc' failed with exit status 1

解决
yum install python-devel
采集获取数据

获取烹饪相关的数据集, 它是由facebook AI实验室提供的演示数据集,该数据集cooking.stackexchange.txt中的每一行都包含一个标签列表,后跟相应的文档

下载获取:
wget https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz && tar xvzf cooking.stackexchange.tar.gz

数据格式:所有标签均以__label__前缀开头,这是fastText识别标签或单词的方式. 标签之后的一段话就是文本信息

__label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
训练集,验证集的划分
训练数据
head -n 12404 cooking.stackexchange.txt > cooking.train
验证数据
tail -n 3000 cooking.stackexchange.txt > cooking.valid
模型训练
>>> model=fasttext.train_supervised(input='cooking.train',epoch=50,lr=1.5,wordNgrams=2)

Read 0M words
Number of words:  14543
Number of labels: 735
Progress: 100.0% words/sec/thread:   80697 lr:  0.000000 loss:  4.606543 ETA:   0h 0m

使用模型预测评估
>>> model.predict('rr')
# 元组中的第一项代表标签, 第二项代表对应的概率
(('__label__toasting',), array([1.00012004]))
模型优化
  1. 数据处理
  2. 增加训练轮数
  3. 学习率大小调整
  4. 增加n-gram特征 wordNgram n-gram特征帮助模型捕捉前后词汇之间的关联, 更好的提取分类规则用于模型分类
  5. 损失计算方式 loss 换一种低复杂度的方式来计算损失
>>> model=fasttext.train_supervised(input='cooking.train',epoch=50,lr=1.5,wordNgrams=2,loss='hs')
Read 0M words
Number of words:  14543
Number of labels: 735
Progress: 100.0% words/sec/thread: 1551340 lr:  0.000000 loss: 21.357746 ETA:   0h 0m

  1. 多标签多分类问题的损失计算方式

    针对多标签多分类问题, 使用’softmax’或者’hs’有时并不是最佳选择, 我们最终得到的应该是多个标签, 而softmax却只能最大化一个标签. 所以选择的损失计算方式为’ova’表示one vs all.

>>> model = fasttext.train_supervised(input="cooking.train", lr=0.2, epoch=50, wordNgrams=2, loss='ova')
Read 0M words
Number of words:  14543
Number of labels: 735
Progress: 100.0% words/sec/thread:   78905 lr:  0.000000 loss:  8.353354 ETA:   0h 0m

  • model.predict
    k:代表指定模型输出多少个标签, 默认为1, 这里设置为-1, 意味着尽可能多的输出
    threshold:代表显示的标签概率阈值, 设置为0.5, 意味着显示概率大于0.5的标签
>>> model.predict("Which baking dish is best to bake a banana bread ?", k=-1, threshold=0.5)
(('__label__baking', '__label__bread'), array([1.00001001, 0.93629503]))

模型保存与加载
模型保存
model.save_model('./model.bin')
模型加载
fasttext.load_model('./model.bin')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值