fasttext文本分类

前言

fastText是Facebook Research在2016年开源的一个词向量及文本分类工具,今天这篇文章主要使用fasttext在来做文本分类,测试fasttext用于分类的实际效果。

本文所使用的数据及代码均已上传至GitHub
传送门: fasttext_classify
由于数据集太大了,无法上传至GitHub,数据集链接:fasttext分类数据集
百度云:链接
提取码:96fu

一、环境

  • python3.8
  • fasttext-0.9.2
  • tqdm

在windows上安装fasttext得去https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext下载对应python版本的whl,然后在命令行使用pip install xxx.whl安装。

二、数据处理

fasttext要求的数据格式有点奇怪,需要将label处理成__label__的格式,假如你原本的label为0,则需要将其处理成__label__0
格式如下:

“并且 在 世界 范围 内 广为流传” + \t + __label__0

这样也可以:

 __label__0 + \t + “并且 在 世界 范围 内 广为流传” 

不知道 \t 换成 , 行不行,感兴趣的同学可以试下

下面进入正题:
首先我们观察下数据:

import pandas as pd
import numpy as np
data = pd.read_csv('./data/train.csv', sep='\t')
list(np.unique(data['label']))
总共35个类别
[' 文化', '中小学教辅','传记', '健身与保健', '农业/林业', '动漫', '励志与成功', '医学', '历史', '哲学/宗教', '国学/古籍','外语学习', '大中专教材教辅', 
'婚恋与两性', '孕产/胎教', '小说', '工业技术', '建筑', '政治/军事', '文学', '旅游/地图', '法律', '烹饪/美食', '社会科学',
 '科学与自然', '科普读物', '童书', '管理', '经济', '考试','育儿/家教', '艺术', '计算机与互联网', '金融与投资', '青春文学']

统计下各标签的数据量:

from collections import Counter
Counter(data['label'])

Counter(data['label']).most_common(5)
Counter({'文学': 13469, '童书': 5996, '大中专教材教辅': 5396, '工业技术': 3292, '中小学教辅': 2603, '
艺术': 2397, '社会科学': 2317, '小说': 2191, '计算机与互联网': 2054, '管理': 1852, '建筑': 1788, '外语
学习': 1494, '历史': 1455, '科学与自然': 1421, '法律': 1256, '政治/军事': 1210, '哲学/宗教': 1012, '医
学': 998, '经济': 938, '励志与成功': 921, '考试': 869, '传记': 761, '青春文学': 746, ' 文化': 707, '农
业/林业': 567, '动漫': 442, '育儿/家教': 390, '烹饪/美食': 375, '国学/古籍': 357, '旅游/地图': 354, '
健身与保健': 348, '科普读物': 329, '孕产/胎教': 301, '金融与投资': 186, '婚恋与两性': 63})
[('文学', 49868),
 ('童书', 18926),
 ('工业技术', 15714),
 ('大中专教材教辅', 12229),
 ('艺术', 10104)]

看来标签不平衡问题挺严重的。

由于时间问题,我们就不去做数据不平衡的相关处理了,感兴趣的同学可以去了解一下过采样、欠采样这里我们直接挑选几个类别出来尝试下,可以看出标签数量前五的类别中,童书、工业技术、大中专教材教辅三个类别的数据量相差不是很大,于是我们挑选出这三个类别来训练我们的三分类模型。

数据处理代码:

def extract_three_cls_data(data_path,save_path, txt_save_path):
    map_path = './base_fasttext/data/three_class/map.json'
    data = pd.read_csv(data_path, sep='\t')
    cls_data = data[(data['label'] == '童书') | (data['label'] == '工业技术') | (data['label'] == '大中专教材教辅')]
    cls_data.index = range(len(cls_data))
    print(Counter(cls_data['label']))
    print('总共 {} 个类别'.format(len(np.unique(cls_data['label']))))
    label_map = {key:index for index, key in enumerate(np.unique(cls_data['label']))}
    label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3)
    if not os.path.exists(label_map_json):
        with open(map_path, 'w', encoding='utf-8') as f:
            f.write(label_map_json)
    cls_data['fasttext_label'] = cls_data['label'].map(label_map)
    for i in range(len(cls_data['fasttext_label'])):
        cls_data['fasttext_label'][i] = '__label__{}'.format(cls_data['fasttext_label'][i])
    print(len(cls_data))
    with open('./data/stopwords.txt', 'r', encoding='utf-8') as f:
        stopwords = f.readlines()
        stopwords = [i.strip() for i in stopwords]
    cls_data.to_csv(save_path, index=False)
    with open(txt_save_path, 'a+', encoding='utf-8') as f:
        for idx,row in tqdm(cls_data.iterrows(), desc='去除停用词:', total=len(cls_data)):
            words = row['text'].split(' ')
            out_str = ''
            for word in words:
                if word not in stopwords:
                    out_str += word
                    out_str += ' '
            row['text'] = out_str.encode('utf-8')

            line = str(row['text']) + '\t' + row['fasttext_label'] + '\n'
            f.write(line)

记得要做下停用词过滤,实验发现过滤停用词可以将准确率提高1%左右
注意一下这一行row['text'] = out_str.encode('utf-8'),在调试代码的过程中我发现,不加encode('utf-8'),生成的txtlen(data)不一致,但训练出来的结果是一样的,暂时没找到啥原因,加入之后就一样了。记得做predict的时候也需要对输入的string做下encode('utf-8')转换。

生成的txt格式:

b'\xe5\xa6\x88\xe5\xa6\x88 \xe6\xb2\xa1 \xe6\x83\xb3 \xe8\xbd\xaf\xe5\xbc\xb1 \xe5\x81\x9a \xe6\x9c\x80 \xe4\xbc\x98\xe7\xa7\x80   1 \xe4\xb8\xbb\xe9\xa2\x98\xe9\xb2\x9c\xe6\x98\x8e \xe7\xa7\xaf\xe6\x9e\x81\xe5\x90\x91\xe4\xb8\x8a \xe5\x85\x85\xe6\xbb\xa1 \xe6\xad\xa3 \xe8\x83\xbd\xe9\x87\x8f   2 \xe5\x85\xa8\xe5\xbd\xa9 \xe6\x8f\x92\xe5\x9b\xbe \xe7\xb2\xbe\xe7\xbe\x8e \xe6\x89\x8b\xe7\xbb\x98 \xe7\x8e\xaf\xe4\xbf\x9d \xe6\xb2\xb9\xe5\xa2\xa8 \xe5\x8d\xb0\xe5\x88\xb7   3 \xe5\x9f\xb9\xe5\x85\xbb \xe5\xad\xa9\xe5\xad\x90 \xe5\x9d\x9a\xe5\xbc\xba \xe6\x80\xa7\xe6\xa0\xbc \xe9\x94\xbb\xe7\x82\xbc \xe5\xad\xa9\xe5\xad\x90 \xe7\x8b\xac\xe7\xab\x8b \xe5\x93\x81\xe6\xa0\xbc   4 \xe6\x95\x99\xe4\xbc\x9a \xe5\xad\xa9\xe5\xad\x90 \xe8\xae\xa4\xe8\xaf\x86 \xe6\xbd\x9c\xe8\x83\xbd \xe6\xa0\x91\xe7\xab\x8b \xe5\xbc\xba\xe5\xa4\xa7 \xe8\x87\xaa\xe4\xbf\xa1\xe5\xbf\x83 '	__label__2
b'\xe6\x9c\xba\xe6\xa2\xb0\xe5\x88\xb6\xe9\x80\xa0 \xe5\xb7\xa5\xe8\x89\xba\xe5\xad\xa6 \xe6\x95\x99\xe6\x9d\x90 \xe7\xbc\x96\xe5\x86\x99 \xe8\xbf\x87\xe7\xa8\x8b \xe4\xb8\xad \xe5\x85\xb7\xe6\x9c\x89 \xe4\xbb\xa5\xe4\xb8\x8b \xe7\x89\xb9\xe8\x89\xb2   1 \xe8\xaf\xb7 \xe7\x90\x86\xe8\xae\xba \xe9\x87\x8d \xe5\xae\x9e\xe8\xb7\xb5   2 \xe4\xbc\x81\xe4\xb8\x9a \xe7\xae\xa1\xe7\x90\x86\xe4\xba\xba\xe5\x91\x98 \xe5\x90\x88\xe4\xbd\x9c \xe7\xbc\x96\xe5\x86\x99\xe6\x95\x99\xe6\x9d\x90 \xe7\xaa\x81\xe5\x87\xba \xe5\xb7\xa5\xe7\xa8\x8b \xe5\xae\x9e\xe4\xbe\x8b \xe5\x88\x86\xe6\x9e\x90 \xe8\xae\xb2\xe8\xa7\xa3   3 \xe8\xb4\xaf\xe5\xbd\xbb \xe5\x90\x8d\xe7\xa7\xb0 \xe6\x9c\xaf\xe8\xaf\xad \xe4\xbb\xa3\xe5\x8f\xb7 \xe9\x87\x8f \xe5\x8d\x95\xe4\xbd\x8d \xe7\x8e\xb0\xe8\xa1\x8c \xe5\x9b\xbd\xe5\xae\xb6\xe6\xa0\x87\xe5\x87\x86 '	__label__1

三、训练

训练代码还是比较简单的,直接将处理好的数据作为输入,再设置下参数,就可以了。

def train_three_class():
    train_data_path = './data/train.csv'
    train_csv_path = './base_fasttext/data/three_class/train.csv'
    train_txt_path = './base_fasttext/data/three_class/train.txt'
    if not os.path.exists(train_txt_path):
        extract_three_cls_data(train_data_path, train_csv_path, train_txt_path)
    test_data_path = './data/test.csv'
    test_csv_path = './base_fasttext/data/three_class/test.csv'
    test_txt_path = './base_fasttext/data/three_class/test.txt'
    if not os.path.exists(test_txt_path):
        extract_three_cls_data(test_data_path, test_csv_path, test_txt_path)
    dev_data_path = './data/dev.csv'
    dev_csv_path = './base_fasttext/data/three_class/dev.csv'
    dev_txt_path = './base_fasttext/data/three_class/dev.txt'
    if not os.path.exists(dev_txt_path):
        extract_three_cls_data(dev_data_path, dev_csv_path, dev_txt_path)
    # classifier = fasttext.train_supervised(input= train_txt_path, autotuneValidationFile = dev_txt_path)
    model_path = './base_fasttext/model/fasttext_three_class.pkl'
    if not os.path.exists(model_path):
        classifier = fasttext.train_supervised(train_txt_path,
                                                label="__label__",
                                                dim=100,
                                                epoch=10,
                                                lr=0.1,
                                                wordNgrams=3,
                                                loss='softmax',
                                                thread=8,
                                                verbose=True,
                                                minCount = 5)
        classifier.save_model(model_path)
        result = classifier.test(test_txt_path)
        print('F1 Score: {}'.format(result[1] * result[2] * 2 / (result[2] + result[1])))
    else:
        classifier = fasttext.load_model(model_path)
        # result = classifier.test(test_txt_path)
        # print('F1 Score: {}'.format(result[1] * result[2] * 2 / (result[2] + result[1])))
    return classifier

得益于分层Softmax,训练过程非常快。

F1 Score: 0.9315296251511487

还是相当不错的,拿个例子来测试一下:

three_classifier = train_three_class()
three_classifier_map_path = './base_fasttext/data/three_class/map.json'
with open(three_classifier_map_path, 'r', encoding='utf-8') as f:
   three_classifier_map = json.load(f)
true_class = '工业技术'
test_data = '通信 原理 - ( 第 3 版 )   本书 系统地 介绍 通信 的 基本概念 、 基本 理论 和 基本 分析方法 。 在 保持 一定 理论 深度 的 基础 上 , 本书 尽可能 简化 数学分析 过程 , 突出 对 概念 、 新 技术 的 介绍 ; 叙述 上 力求 概念 清楚 、 重点 突出 、 深入浅出 、 通俗易懂 ; 内容 上 力求 科学性 、 先进性 、 系统性 与 实用性 的 统一 。   本书 共 10 章 , 内容 包括 : 绪论 、 信号 与 噪声 分析 、 模拟 调制 系统 、 模拟信号 的 数字传输 、 数字信号 的 基带 传输 、 数字信号 的 载波 传输 、 现代 数字 调制 技术 、 信道 、 信道编码 和 扩频通信 。 内容 涵盖 国内 通信 原理 教学 的'.encode('utf-8')
result = three_classifier.predict(str(test_data))[0][0]
predicT_class = list(three_classifier_map.keys())[list(three_classifier_map.values()).index(int(result[-1]))]
print('预测的类别为:{}'.format(predicT_class))
print('真实的类别为:{}'.format(true_class))

记得测试用例要用encode('utf-8')转换一下

预测的类别为:工业技术
真实的类别为:工业技术

预测正确。

关于调参: 也可以使用fasttext的自动寻参来训练,但是太慢了,五六分钟还没搞定,于是我放弃了。

classifier = fasttext.train_supervised(input= train_txt_path, autotuneValidationFile = dev_txt_path)

三分类的准确率达到了93%,效果相当不错,那么在35个类别上的效果怎么样呢?
于是,我用所有数据测试了一下:

F1 Score: 0.755784146181149
预测的类别为:励志与成功
真实的类别为:工业技术

35个类别有75%的准确率,整体效果还不错。
由于存在严重的数据不平衡问题,在单一类别的准确率应该翻车了,这里就不再测试了。感兴趣的同学可以自己测试后在评论区留言。

总结

1、fasttext在文本分类任务上效果确实很不错
2、fasttext采用层次化 softmax使其训练速度非常快

本文所有代码及数据Github链接:fasttext_classify
相关文章:TextCNN文本分类Pytorch

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ToTensor

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值