Python 3.7.3 (default, Mar 27 2019, 22:11:17)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import fasttext
>>>
__label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?
__label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
__label__cast-iron __label__stove How do I cover up the white spots on my cast iron stove?
__label__restaurant Michelin Three Star Restaurant; but if the chef is not there
__label__knife-skills __label__dicing Without knife skills, how can I quickly and accurately dice vegetables?
__label__storage-method __label__equipment __label__bread What's the purpose of a bread box?
__label__baking __label__food-safety __label__substitutions __label__peanuts how to seperate peanut oil from roasted peanuts at home?
__label__chocolate American equivalent for British chocolate terms
__label__baking __label__oven __label__convection Fan bake vs bake
__label__sauce __label__storage-lifetime __label__acidity __label__mayonnaise Regulation and balancing of readymade packed mayonnaise and other sauces
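Data description:
* Each line of cooking.stackexchange.txt contains a list of labels followed by the corresponding document. The label list looks like "__label__sauce __label__cheese", which means the line carries two labels, sauce and cheese. Every label starts with the prefix __label__, which is how fastText tells labels apart from words. The text after the labels is the document itself, for example: How much does potato starch affect a cheese sauce recipe?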
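The line format described above can be illustrated with a small parser. This is only a sketch: the helper name `parse_line` is hypothetical, and it assumes that all labels appear at the start of the line, as they do in the samples shown.

```python
LABEL_PREFIX = "__label__"


def parse_line(line):
    """Split one fastText-format line into (labels, text).

    Assumption: labels only appear as leading tokens, as in the samples.
    """
    tokens = line.split()
    labels = []
    i = 0
    # Consume leading tokens that carry the label prefix
    while i < len(tokens) and tokens[i].startswith(LABEL_PREFIX):
        labels.append(tokens[i][len(LABEL_PREFIX):])
        i += 1
    # Everything after the labels is the document text
    return labels, " ".join(tokens[i:])


labels, text = parse_line(
    "__label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?"
)
print(labels)  # ['sauce', 'cheese']
print(text)    # How much does potato starch affect a cheese sauce recipe?
```

A helper like this is handy for checking a data file before training, for example to count how many labels each document has.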
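# The code runs in the Python interpreter
# Import fasttext
>>> import fasttext
# Train a text classification model with fasttext's train_supervised method
>>> model = fasttext.train_supervised(input="cooking.train")
# Output
Read 0M words
# Total number of distinct words
Number of words: 14543
# Total number of labels
Number of labels: 735
# Progress: training progress; this is the final line printed once training has finished, so it shows 100%
# words/sec/thread: average number of words processed per second by each thread
# lr: current learning rate; it is 0 because training has finished
# avg.loss: average loss over the training run
# ETA: estimated remaining training time; it is 0 because training has finished
Progress: 100.0% words/sec/thread: 60162 lr: 0.000000 avg.loss: 10.056812 ETA: 0h 0m 0s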
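The learning rate in the final progress line is 0 because the rate shrinks as training advances. The sketch below illustrates one way this can happen, a linear decay over the training progress. The function name `decayed_lr` is hypothetical, and the starting value of 0.1 is an assumption chosen for illustration rather than something stated in this tutorial.

```python
def decayed_lr(initial_lr, progress):
    """Linearly decay the learning rate; progress is a fraction in [0, 1]."""
    return initial_lr * (1.0 - progress)


# Assumed initial learning rate of 0.1, for illustration only
for progress in (0.0, 0.5, 1.0):
    print(f"progress={progress:.0%} lr={decayed_lr(0.1, progress):.6f}")
```

At 100% progress the rate reaches exactly 0, which matches the lr value shown once training is complete.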
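Step 4: Use the model to predict and evaluate
# Predict the label of an input text; common sense tells us the prediction is correct, but the predicted probability is low
>>> model.predict("Which baking dish is best to bake a banana bread ?")
# The first item of the tuple holds the labels, the second holds the corresponding probabilities
(('__label__baking',), array([0.06550845]))
# Common sense tells us this prediction is wrong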
>>> model.predict("Why not put knives in the dishwasher?")
(('__label__food-safety',), array([0.07541209]))
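The tuple returned by predict pairs each label with its probability. The helper below is a minimal sketch of how such a result could be unpacked into readable pairs. The name `format_prediction` is hypothetical, and a plain list stands in for the NumPy array so that the snippet runs without a trained model.

```python
def format_prediction(prediction, prefix="__label__"):
    """Turn (labels, probabilities) into a list of (label, probability) pairs."""
    labels, probs = prediction
    pairs = []
    for label, prob in zip(labels, probs):
        # Strip the fastText label prefix when present
        name = label[len(prefix):] if label.startswith(prefix) else label
        pairs.append((name, float(prob)))
    return pairs


# Stand-in for the value returned by model.predict(...) above
result = (("__label__baking",), [0.06550845])
print(format_prediction(result))  # [('baking', 0.06550845)]
```

Printing the probability next to the label makes it easy to spot low-confidence predictions such as the two above.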
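# Looking at the data, we find many punctuation marks attached to words and inconsistent capitalization.
# These do not help the classification goal and only make it harder for the model to learn classification patterns,
# so we remove or normalize them.
# Sample of the data before processing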
__label__fish Arctic char available in North-America
__label__pasta __label__salt __label__boiling When cooking pasta in salted water how much of the salt is absorbed?
__label__coffee Emergency Coffee via Chocolate Covered Coffee Beans?
__label__cake Non-beet alternatives to standard red food dye
__label__cheese __label__lentils Could cheese "halt" the tenderness of cooking lentils?
__label__asian-cuisine __label__chili-peppers __label__kimchi __label__korean-cuisine What kind of peppers are used in Gochugaru ()?
__label__consistency Pavlova Roll failure
__label__eggs __label__bread What qualities should I be looking for when making the best French Toast?
__label__meat __label__flour __label__stews __label__braising Coating meat in flour before browning, bad idea?
__label__food-safety Raw roast beef on the edge of safe?
__label__pork __label__food-identification How do I determine the cut of a pork steak prior to purchasing it?
# Sample of the data after processing
__label__fish arctic char available in north-america
__label__pasta __label__salt __label__boiling when cooking pasta in salted water how much of the salt is absorbed ?
__label__coffee emergency coffee via chocolate covered coffee beans ?
__label__cake non-beet alternatives to standard red food dye
__label__cheese __label__lentils could cheese "halt" the tenderness of cooking lentils ?
__label__asian-cuisine __label__chili-peppers __label__kimchi __label__korean-cuisine what kind of peppers are used in gochugaru ( ) ?
__label__consistency pavlova roll failure
__label__eggs __label__bread what qualities should i be looking for when making the best french toast ?
__label__meat __label__flour __label__stews __label__braising coating meat in flour before browning , bad idea ?
__label__food-safety raw roast beef on the edge of safe ?
__label__pork __label__food-identification how do i determine the cut of a pork steak prior to purchasing it ?
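The before and after samples are consistent with two transformations: putting spaces around punctuation and lowercasing everything. The snippet below is a sketch that reproduces the samples shown. The function name `normalize` and the exact punctuation set are assumptions, since the command actually used is not shown here.

```python
import re

# Assumed punctuation set; quotes and hyphens are left alone, as in the samples
PUNCT = re.compile(r"([.!?,'/()])")


def normalize(line):
    """Separate punctuation from words, lowercase, and collapse whitespace."""
    spaced = PUNCT.sub(r" \1 ", line)
    return " ".join(spaced.lower().split())


print(normalize("__label__meat __label__flour Coating meat in flour before browning, bad idea?"))
# __label__meat __label__flour coating meat in flour before browning , bad idea ?
```

Applied line by line to the raw file, a function like this would produce the training and validation data in the processed form shown above.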
Train and test again after processing the data:
# Retrain
>>> model = fasttext.train_supervised(input="cooking.train")
Read 0M words
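# The number of distinct words drops sharply, because words containing capital letters or attached to punctuation were previously counted as separate words
Number of words: 8952
Number of labels: 735
# The average loss has decreased
Progress: 100.0% words/sec/thread: 65737 lr: 0.000000 avg.loss: 9.966091 ETA: 0h 0m 0s
# Test again
>>> model.test("cooking.valid")
# Both precision and recall have improved
(3000, 0.161, 0.06962663975782038)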
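The tuple returned by test holds the number of samples, the precision, and the recall. The sketch below shows how precision and recall at k can be computed for multi-label data. The function name `precision_recall_at_k` and the toy data are hypothetical, and the micro-averaged definition used here is an assumption about how such figures are usually reported.

```python
def precision_recall_at_k(true_labels, predicted_labels, k=1):
    """Micro-averaged precision and recall over the top-k predictions."""
    hits = 0         # predicted labels that are actually correct
    n_predicted = 0  # total number of labels predicted
    n_true = 0       # total number of true labels
    for truth, preds in zip(true_labels, predicted_labels):
        top_k = preds[:k]
        hits += len(set(top_k) & set(truth))
        n_predicted += len(top_k)
        n_true += len(truth)
    return hits / n_predicted, hits / n_true


# Toy data: three documents, one prediction each
truth = [["sauce", "cheese"], ["baking"], ["food-safety", "acidity"]]
preds = [["sauce"], ["oven"], ["acidity"]]
precision, recall = precision_recall_at_k(truth, preds, k=1)
print(precision, recall)
```

With a single prediction per document, recall stays below precision whenever documents carry several labels, which fits the gap between the two figures in the test output.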
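Increase the number of training epochs:
# Set the epoch parameter of train_supervised to increase the number of training epochs; the default is 5
# More epochs give the model more chances to adjust its classification patterns on limited data, at the cost of longer training time
>>> model = fasttext.train_supervised(input="cooking.train", epoch=25)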
Read 0M words
Number of words