无痛接入FastText算法进行文本分类(附代码)

AI应用开发相关目录

本专栏包括AI应用开发相关内容分享,包括不限于AI算法部署实施细节、AI应用后端分析服务相关概念及开发技巧、AI应用后端应用服务相关概念及开发技巧、AI应用前端实现路径及开发技巧
适用于具备一定算法及Python使用基础的人群

  1. AI应用开发流程概述
  2. Visual Studio Code及Remote Development插件远程开发
  3. git开源项目的一些问题及镜像解决办法
  4. python实现UDP报文通信
  5. python实现日志生成及定期清理
  6. Linux终端命令Screen常见用法
  7. python实现redis数据存储
  8. python字符串转字典
  9. python实现文本向量化及文本相似度计算
  10. python对MySQL数据的常见使用
  11. 一文总结python的异常数据处理示例
  12. 基于selenium和bs4的通用数据采集技术(附代码)
  13. 基于python的知识图谱技术
  14. 一文理清python学习路径
  15. Linux、Git、Docker常用指令
  16. linux和windows系统下的python环境迁移
  17. linux下python服务定时(自)启动
  18. windows下基于python语言的TTS开发
  19. python opencv实现图像分割
  20. python使用API实现word文档翻译
  21. yolo-world:”目标检测届大模型“
  22. 爬虫进阶:多线程爬虫
  23. python使用modbustcp协议与PLC进行简单通信
  24. ChatTTS:开源语音合成项目
  25. sqlite性能考量及使用(附可视化操作软件)
  26. 拓扑数据的关键点识别算法
  27. python脚本将视频抽帧为图像数据集
  28. 图文RAG组件:360LayoutAnalysis中文论文及研报图像分析
  29. Ubuntu服务器的GitLab部署
  30. 无痛接入图像生成风格迁移能力:GAN生成对抗网络
  31. 一文理清OCR的前世今生
  32. labelme使用笔记
  33. HAC-TextRank算法进行关键语句提取
  34. Segment any Text:优质文本分割是高质量RAG的必由之路
  35. 无痛接入FastText算法进行文本分类


简介

FastText的特点如下:
速度:FastText的设计初衷就是为了高效,它的训练速度比许多其他文本处理工具快得多。
简单:FastText的模型结构相对简单,易于理解和实现。
准确性:尽管模型简单,但FastText在许多文本分类任务中都能达到与其他复杂模型相媲美的准确度。
多功能性:除了文本分类,FastText还可以用于词嵌入的生成,它可以捕捉到词的语义信息,比如相似的词在嵌入空间中会彼此接近。
支持多语言:FastText能够处理多种语言的文本,这对于跨语言文本分类任务非常有用。
无需大量数据:对于一些小语种或者数据稀缺的场景,FastText也能够有效工作。

FastText的核心思想是将文本数据转换成向量表示,然后使用这些向量进行分类或相似度计算。它使用了一种层次化的softmax技术来加速训练过程,并采用了负采样来改善分类性能。

在文本分类任务中,FastText将文本转换成一系列的n-gram,然后通过模型学习每个n-gram的权重,最后将这些权重组合起来,形成整个文本的向量表示。这个向量随后被送入一个softmax层进行分类。

此外,文本分类算法在大模型领域也具有一定的应用前景:

训练FastText分类器,从大量领域不可知数据中识别领域内数据。具体来说,为了训练FastText分类器,选择了一定数量的领域内数据作为正面样本,并选择了同等数量的领域外数据作为负面样本。训练好的二元分类器随后被用来从通用语料库(例如,网络语料库)中选择领域内数据;其次,应用过滤器,以确保领域内数据(包括原始的领域内语料库和选择的数据)具有高教育价值。通过这种方式,可以提高过滤后的领域内数据的质量,进而提高模型的性能。

数据集情况

停用词集:
在这里插入图片描述
原始数据集:
在这里插入图片描述

环境部署:

https://pypi.org/project/fasttext-wheel/#files

下载对应版本的fasttext.whl文件,直接pip会出现轮子文件构建失败的现象。
此外下载一下处理繁体文本的文字:

pip install langconv -i https://pypi.tuna.tsinghua.edu.cn/simple

代码及使用

项目文件夹总体情况:
在这里插入图片描述

数据清理文件text_cleaner.py:

# -*- coding: utf-8 -*-

from types import MethodType, FunctionType
import jieba
import re
# 导入用于繁体/简体转换的包
from langconv import *


def clean_txt(raw):
    fil = re.compile(r"[^0-9a-zA-Z\u4e00-\u9fa5]+")
    return fil.sub(' ', raw)


def seg(sentence, sw, apply=None, cut_all=False):
    """
    对中文文本去特殊符号、去停用词、分词
    :param sentence: 原始中文文本
    :param sw:
    :param apply:
    :param cut_all:
    :return: 分词后中文文本
    """
    if isinstance(apply, FunctionType) or isinstance(apply, MethodType):
        sentence = apply(sentence)
    return ' '.join([i for i in jieba.cut(sentence, cut_all=cut_all) if i.strip() and i not in sw])


def stop_words():
    with open('stopwords.txt', 'r', encoding='utf-8') as swf:
        stopwords = [i.strip() for i in swf.readlines()]
    return stopwords


def cht_to_chs(line):
    """
    中文繁体文本转简体
    :param line: 原始文本
    :return: 中文简体文本
    """
    line = Converter('zh-hans').convert(line)
    line.encode('utf-8')
    return line


def replace_text(input_str, str_targ, str_rep):
    if isinstance(input_str, list):
        return [replace_text(s, str_targ, str_rep) for s in input_str]
    return input_str.replace(str_targ, str_rep)


# 对某个sentence进行处理:
if __name__ == '__main__':
    content = '海尔(Haier)新风机 室内外空气交换 恒氧新风机 XG-100QH/AA'
    res = seg(content.lower().replace('\n', ''), stop_words(), apply=clean_txt)
    print(res)
    test = stop_words()
    print(test)

数据预处理步骤,生成中间csv文档,训练集文档,验证集文档:

# -*- coding: utf-8 -*-


import pandas as pd
import numpy as np
import random
from text_cleaner import *
from tqdm import tqdm
import os
import re
from sklearn.utils import shuffle

def load_df(file_path, encoding='utf-8', drop_dup=True, drop_na=True):
    """
    从csv文件读取dataframe
    :param file_path: csv文件路径
    :param encoding: 编码,默认 UTF-8
    :param drop_dup: 去掉重复行
    :param drop_na: 去掉空行
    :return: dataframe
    """
    df = pd.read_csv(file_path, encoding=encoding, engine='python')
    if drop_dup:
        df = df.drop_duplicates()
    if drop_na:
        df = df.dropna()
    return df


def write_txt(file_name, df_data, delimiter=' ', fmt="%s", encoding='utf-8'):
    """
    把dataframe写入txt,用于DF转fasttext训练集
    :param delimiter:
    :param file_name: 写入的txt文件路径
    :param df_data: <label> <text>型的DataFrame
    :param fmt: 格式,默认为字符串
    :param encoding: 编码,默认为 UTF-8
    :return:
    """

    np.savetxt(file_name, df_data.values, delimiter=delimiter, fmt=fmt, encoding=encoding)


def dataframe_split(df_text, train_ratio):
    """
    将dataframe按比例分割
    :param df_text: 原始dataframe
    :param train_ratio: 训练集占比
    :return: 训练集和验证集
    """
    df_text = shuffle(df_text)
    train_set_size = int(len(df_text) * train_ratio)
    valid_set_size = int(len(df_text) * (1 - train_ratio))
    df_train_data = df_text[:train_set_size]
    df_valid_data = df_text[train_set_size:(train_set_size + valid_set_size)]
    return df_train_data, df_valid_data


def count_diff_in_col(df_text, col_name):
    """
    统计某一列不同种类的个数
    :param df_text: dataframe
    :param col_name: 需要统计的列名
    :return: 一个字典
    """
    col_set = set(df_text[col_name].values)
    col_list = list(df_text[col_name].values)
    compute = dict()
    for item in col_set:
        compute.update({item: col_list.count(item)})
    return dict(sorted(compute.items()))


def drop_rows_where_col_has(dataframe, col_name, target):
    """
    删除 dataframe中, col_name列包含target的行
    :param dataframe:
    :param col_name:
    :param target:
    :return: 新的dataframe
    """
    return dataframe.drop(dataframe[dataframe[col_name] == target].index)


def df_data_augmentation(dataframe, col_label='label', col_text='text', num_sample=50, sample_length=18):
    """
    将每一类标签的样本扩充至指定数量
    :param dataframe:
    :param col_label:
    :param col_text:
    :param num_sample: 扩充后每个种类样本的数量,默认50
    :param sample_length: 样本文本的长度, 默认18
    :return: 返回扩充后的dataframe 和 记录不同标签样本的字典
    """
    dict_tmp = count_diff_in_col(dataframe, col_label)
    df_sample = dataframe.copy(deep=True)
    for key in list(dict_tmp.keys()):
        if dict_tmp[key] < num_sample:
            df_tmp = df_sample[(df_sample[col_label] == key)]
            list_text = []
            for text in df_tmp[col_text].values.tolist():
                list_text.extend(text.split())
            while dict_tmp[key] < num_sample:
                str_tmp = ' '.join(random.sample(list_text, sample_length))
                df_sample = df_sample.append({col_label: key, col_text: str_tmp}, ignore_index=True)
                dict_tmp.update({key: dict_tmp[key] + 1})
    return df_sample, dict_tmp


def repalce_df_text(dataframe, col_name, str_targ, str_rep):
    """
    将dataframe中某一列的字符串中 str_tage 替换为 str_rep
    :param dataframe:
    :param col_name:
    :param str_targ:
    :param str_rep:
    :return:
    """
    li0 = dataframe[col_name].values.tolist()
    li1 = replace_text(li0, str_targ, str_rep)
    if len(li1) == len(li0):
        dataframe[col_name] = li1
        return dataframe
    else:
        print('Lenghth of dataframe has been changed !')
        return -1


def df_cut_ch(dataframe, col_name, save_path=''):
    """
    对 dataframe的col_name列中的中文文本分词, 默认cut_all
    :param dataframe:
    :param col_name:
    :param save_path:保存路径,默认不保存
    :return:
    """
    df_cut = dataframe.copy(deep=True)
    text_cut = []
    stopwords = stop_words()

    for text in tqdm(dataframe[col_name].astype(str)):
        datacutwords = ' '.join([i for i in jieba.cut(text) if i.strip() and i not in stopwords])
        text_cut.append(datacutwords)

    del df_cut[col_name]
    df_cut[col_name] = text_cut
    if len(save_path):
        df_cut.to_csv(save_path, encoding='utf_8_sig', index=False)
    return df_cut

# 头条数据抽取方法
def getdata(filepath):
    data = pd.read_table(filepath, encoding='utf-8', sep='\n', header=None)
    pattern = re.compile('(\d+)_!_(.*?)_!_(.*?)_!_(.*?)_!_([\s\S]*)')
    datanew = data[0].str.extract(pattern)
    columsdic = {
        0:'id'
        ,1:'num'
        ,2:'catgor'
        ,3:'content'
        ,4:'keys'
    }
    datanew.rename(columns=columsdic, inplace=True)
    datanew['text'] = datanew['content'] + datanew['keys']
    datanew['pre'] = '__label__'
    datanew['labels'] = datanew['pre']+datanew['catgor']
    datanew = datanew.loc[:, ['labels', 'text']]
    return datanew

if __name__ == '__main__':

    basepath = os.getcwd()
    file = r'toutiao_cat_data.txt'
    filepath = os.path.join(basepath, file)
    print(filepath)
    # print(getdata(filepath))
    cutfile = r'cutdata.csv'
    cutfilepath = os.path.join(basepath, cutfile)
    datanotcut = getdata(filepath)
    datacut = df_cut_ch(dataframe=datanotcut, col_name='text', save_path=cutfilepath)

    data = pd.read_csv(cutfilepath, encoding='utf-8')
    data = shuffle(data)
    df_train_data, df_valid_data = dataframe_split(df_text=data, train_ratio=0.7)
    write_txt(file_name=r'df_train_data.txt', df_data=df_train_data, encoding='utf-8')
    write_txt(file_name=r'df_valid_data.txt', df_data=df_valid_data, encoding='utf-8')

在这里插入图片描述
在这里插入图片描述
模型训练及测试:

# -*- coding: utf-8 -*-

import fasttext
import pandas as pd
from sklearn.metrics import classification_report
import os
import time
import re

report_index = 0


def train_model(train_file, dim=100, epoch=100, lr=0.5, loss='softmax', wordNgrams=2, save_dir=''):
    """
    训练fasttext模型并保存在 save_dir 文件夹, 详细参数参阅
    https://fasttext.cc/docs/en/python-module.html#train_supervised-parameters
    :param train_file: 训练数据文件
    :param dim: 词向量大小, 默认100
    :param epoch: 默认100
    :param lr: 学习率, 默认0.5
    :param loss: 损失函数,默认softmax, 多分类问题推荐 ova
    :param wordNgrams: 默认2
    :param save_dir: 模型保存文件夹,默认不保存
    :return: 文本分类器模型
    """
    classifier = fasttext.train_supervised(train_file, label='__label__', dim=dim, epoch=epoch,
                                           lr=lr, wordNgrams=wordNgrams, loss=loss)
    if len(save_dir):
        # model_name = f'model_dim{str(dim)}_epoch{str(epoch)}_lr{str(lr)}_loss{str(loss)}' \
        #              f'_ngram{str(wordNgrams)}_{str(report_index)}.model'
        # if not os.path.exists(save_dir):
        #     os.mkdir(save_dir)
        # classifier.save_model(os.path.join(save_dir, model_name))
        classifier.save_model(savepath)
    return classifier


def give_classification_report(classifier, valid_csv, col_label=0, col_text=1, report_file=''):
    """
    使用 classification_report 验证 fasttext 模型分类效果,需在_FastText 类中添加dict_args()属性
    :param classifier: fasttext文本分类模型
    :param valid_csv: 验证数据集,需要csv格式
    :param col_label: 标签列名,默认 'label'
    :param col_text: 文本列名, 默认'text'
    :param report_file: 保存report文件名,默认不保存
    :return: classification report
    """
    df_valid = pd.read_table(valid_csv, sep='\n', header=None)
    print(df_valid)
    pattern = re.compile('(.*?) ([\s\S]*)')
    datause = df_valid[0].str.extract(pattern)
    print(classifier.predict(str(datause.iloc[0,1]))[0][0])
    alluse = []
    for i in range(datause.shape[0]):
        predict = classifier.predict(str(datause.iloc[i,1]))[0][0]
        alluse.append(predict)
    datause["predicted"] = alluse
    print(datause)
    tags = list(set(datause[0]))
    report = classification_report(datause[0].tolist()
                                   , datause["predicted"].tolist()
                                   , target_names=tags)
    return report
train_file = r'df_train_data.txt'
savepath = r'model.bin'
rainmodel = train_model(train_file=train_file, save_dir=savepath)
modeluse = fasttext.load_model(r'model.bin')
valid_csv = r'df_valid_data.txt'
report_file = r'report_file.txt'
report = give_classification_report(classifier=modeluse
                           , valid_csv=valid_csv
                           , col_label= 0,
                           col_text =1,
                           report_file = '')
print(report)
precision    recall  f1-score   support

          __label__news_edu       0.90      0.89      0.90      5835
      __label__news_culture       0.95      0.93      0.94     10658
       __label__news_travel       0.88      0.91      0.89      8519
        __label__news_world       0.92      0.92      0.92      7961
      __label__news_finance       0.91      0.92      0.92     11853
        __label__news_house       0.85      0.84      0.84      8213
         __label__news_tech       0.94      0.94      0.94      8749
          __label__news_car       0.94      0.93      0.93      5324
       __label__news_sports       0.90      0.89      0.89      7556
        __label__news_story       0.97      0.96      0.96     11319
         __label__news_game       0.88      0.77      0.82      1880
             __label__stock       0.88      0.90      0.89     12442
  __label__news_agriculture       0.85      0.87      0.86      6390
     __label__news_military       0.84      0.86      0.85      7993
__label__news_entertainment       0.35      0.06      0.10       114

                   accuracy                           0.90    114806
                  macro avg       0.86      0.84      0.84    114806
               weighted avg       0.90      0.90      0.90    114806

其中数据集可参考:https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset

  • 13
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

写代码的中青年

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值