Fasttext文本分类

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_20989105/article/details/82255434

一、简介

###1、简介

fasttext是facebook开源的一个词向量与文本分类工具,在2016年开源,典型应用场景是“带监督的文本分类问题”。提供简单而高效的文本分类和表征学习的方法,性能比肩深度学习而且速度更快。

fastText结合了自然语言处理和机器学习中最成功的理念。这些包括了使用词袋以及n-gram袋表征语句,还有使用子字(subword)信息,并通过隐藏表征在类别间共享信息。我们另外采用了一个softmax层级(利用了类别不均衡分布的优势)来加速运算过程。

2、原理

fastText方法包含三部分,模型架构,层次SoftMax和N-gram特征。

  • 模型架构

    fastText的架构和word2vec中的CBOW的架构类似,因为它们的作者都是Facebook的科学家Tomas Mikolov,而且确实fastText也算是words2vec所衍生出来的。

    fastText 模型输入一个词的序列(一段文本或者一句话),输出这个词序列属于不同类别的概率。

    序列中的词和词组组成特征向量,特征向量通过线性变换映射到中间层,中间层再映射到标签。

    fastText 在预测标签时使用了非线性激活函数,但在中间层不使用非线性激活函数。fastText 模型架构和 Word2Vec 中的 CBOW 模型很类似。不同之处在于,fastText 预测标签,而 CBOW 模型预测中间词。

  • 层次SoftMax

    对于有大量类别的数据集,fastText使用了一个分层分类器(而非扁平式架构)。不同的类别被整合进树形结构中(想象下二叉树而非 list)。在某些文本分类任务中类别很多,计算线性分类器的复杂度高。为了改善运行时间,fastText 模型使用了层次 Softmax 技巧。层次 Softmax 技巧建立在哈弗曼编码的基础上,对标签进行编码,能够极大地缩小模型预测目标的数量。

    fastText 也利用了类别(class)不均衡这个事实(一些类别出现次数比其他的更多),通过使用 Huffman 算法建立用于表征类别的树形结构。因此,频繁出现类别的树形结构的深度要比不频繁出现类别的树形结构的深度要小,这也使得进一步的计算效率更高。

  • N-gram特征

    fastText 可以用于文本分类和句子分类。不管是文本分类还是句子分类,我们常用的特征是词袋模型。但词袋模型不能考虑词之间的顺序,因此 fastText 还加入了 N-gram 特征。“我 爱 她” 这句话中的词袋模型特征是 “我”,“爱”, “她”。这些特征和句子 “她 爱 我” 的特征是一样的。如果加入 2-Ngram,第一句话的特征还有 “我-爱” 和 “爱-她”,这两句话 “我 爱 她” 和 “她 爱 我” 就能区别开来了。当然啦,为了提高效率,我们需要过滤掉低频的 N-gram。

二、安装

1、使用make构建fastText(首选)

$ wget https://github.com/facebookresearch/fastText/archive/v0.1.0.zip
$ unzip v0.1.0.zip
$ cd fastText-0.1.0
$ make

这个方法下载是在fastTest的最新稳定版本。

2、使用cmake构建fastText

$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ mkdir build && cd build && cmake ..
$ make && make install

这将创建fasttext二进制文件以及所有相关库(共享,静态,PIC),这个下载是最新的主干版本。

3、为Python构建fastText

$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .

##三、语料准备

语料准备绝对是在文本分类里面占大头的重头戏,比较花精力。

这里测试用了舆情数据的标题数据,数据量如下

训练数据 测试数据 分类数
1.4w 3.5k 16类

语料的处理我这边是用java进行处理的,毕竟java才是最熟悉的。。

下面是语料的格式,我的语料是通过分词的,分词是使用HanLP的分词实现的:

申能股份 启动 回购   __label__104003
汤臣倍健 海外 并购 疑点 难 消 标的 公司 虚增 利润 嫌疑   __label__104001
现金流 硬 约束 大面积 回购 潮 难 现   __label__104003
美都能源 出售 房地产 业务 资产 集 中发展 新能源车 产业   __label__104001
四通股份 1.69亿 限售股 月 日 上市 流通   __label__104004
赢合科技 两 名 董事 名 监事 辞职   __label__103006
上海建工 部分 高 合计 增持 公司 125万 股   __label__104001
中海油 原 副 总经理 李凡荣 调任 国家能源局 副局长   __label__103006
银信科技 7920万 股 日 解禁   __label__104004
新潮实业 两 高 工作 原因 辞职   __label__103006
中国 债市 上证所 高 评级 可转债 参与 质押 式 回购 助 增 流动性   __label__104008
重庆啤酒 总 会计师 总经理 助理 辞职   __label__103006
重磅 观 汽车 官方 确认 决定 引进 新的 战略 投资者   __label__104006
振华重工 控股 股东 变更 事宜 获 批   __label__104007
内蒙华电 控股 股东 方面 完成 增持 计划   __label__104001
......

每一行是一组数据,前一部分是经过分词后的标题数据,后面是固定格式__label__加上所属分类的id,两组之间用3个空格分割。

四、用例示例

FastText主要可以用于单词表示学习和文本分类。单词表示学习说白了就是词向量模型的训练,我们着重介绍文本分类。

1、采用命令行的方式

FastText提供命令行的方式来进行训练模型和分类,主要使用的命令有训练、评估模型、对数据分类。

  • 训练
$ ./fasttext supervised -input train.txt -output model

./fasttext:这个即为在2.1中使用make构建fastText的目录

supervised:表示监督学习,即训练数据

-input train.txt:表示输入文件为train.txt,即是我们训练的语料文件,每行包含一个训练语句以及标签。默 认情况下,标签是以字符串为前缀的单词__label__

-output model:表示输出的模型路径

训练操作这将输出两个文件:model.binmodel.vec

  • 评估模型

训练模型后,您可以通过计算精度并在测试集上以k(P @ k和R @ k)调用来评估它:

$ ./fasttext test model.bin test.txt k

test.txt是我们的测试文件,参数k是可选的,默认情况下等于1,例如为了获得一段文本的10个最可能的标签,可使用:

$ ./fasttext predict model.bin test.txt 10
  • 获取概率

    predict-prob用来获得每个标签的概率

$ ./fasttext predict-prob model.bin test.txt k

其中test.txt包含一行文本以按行分类。这样做会向标准输出打印每行最有可能的k个标签。参数k是可选的,默认情况下等于1

  • 量化模型

可以使用以下命令量化受监控的模型以减少其内存使用量:

$ ./fasttext quantize -output model

这将创建一个.ftz内存占用较小的文件。所有操作和model.bin的模型一样,例如评估模型:

$ ./fasttext test model.ftz test.txt

实际操作一下

//训练
$ ./fastText-0.1.0/fasttext supervised -input 0821-title/train.txt -output model
Read 0M words
Number of words:  17352
Number of labels: 15
Progress: 100.0%  words/sec/thread: 524129  lr: 0.000000  loss: 0.191040  eta: 0h0m 
//评估
$ ./fastText-0.1.0/fasttext test ./model.bin ./0821-title/test.txt
N       0
P@1     -nan
R@1     -nan
Number of examples: 0

2、采用python的方式

编写了python脚本,评估使用了sklean的文本分类的评估方法,直接上代码,简单干脆,具体的都已经注释在脚本中。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug  1 11:12:10 2018
@author: chenyang
"""
#这里是导入fasttext库
import fasttext
import os
import sys
#这里是导入sklearn库
from sklearn import metrics

#这里是为了中文乱码做的一些转码
if sys.version_info[0] > 2:
    is_py3 = True
else:
    is_py3 = False

def native_content(content):
    if not is_py3:
        return content.decode('utf-8')
    else:
        return content
        
#这个方法是读取文件,主要是训练集和测试集
def open_file(filename, mode='r'):
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    else:
        return open(filename, mode)

#这个方法是切割数据,以3个空格进行分割,返回内容和对应标签的2个list
def read_file(filename):
    """读取文件数据"""
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                content,label = line.strip().split('   ')
                if content:
                    contents.append(native_content(content))
                    labels.append(native_content(label))
            except:
                pass
    return contents, labels

#主方法
if __name__ == '__main__':
    #判断输入参数是否含有语料路径的参数
    if len(sys.argv) != 2:
        print("The number of parameters is not correct!")
        exit()

    filename=sys.argv[1]
    print("input param:%s" % filename)
    print(os.path.exists('./%s/model.bin' % filename))
    
    classifier = None
    #判断是否已经存在模型,如果存在则加载,不存在则进行训练
    if(os.path.exists('./%s/model.bin' % filename)):
        classifier =fasttext.load_model('./%s/model.bin' % filename)
    else:
        #训练模型
        if(not os.path.exists('./%s/' % filename)):
            os.mkdir('./%s/' % filename)
        fasttext.supervised('../%s/train.txt' % filename,'./%s/model' % filename)
        classifier =fasttext.load_model('./%s/model.bin' % filename)
    
    #读取测试数据
    lines, test_cls = read_file("../%s/test.txt" % filename);
    
    print("data sum: ",len(lines))
    #获取测试数据的分类结果
    pred = classifier.predict(lines)
    
    pred_cls = []
    for x in pred:
        pred_cls.append(x[0])
    # 评估
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(test_cls, pred_cls))
    # 混淆矩阵
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(test_cls, pred_cls)
    print(cm)

调用脚本方式:

//进入脚本所在目录
$ python fasttext-train.py 0820-title

0820-title:为语料所在的路径名

##五、附预料处理代码:

/**
 * @description: 构建训练语料
 * @author: chenyang
 * @create: 2018-06-08
 **/
public class BuildCorpus {

    //训练语料所在路径
    private final static String PATH = "C:\\Users\\chenyang\\Desktop\\0821";
    //生成的训练文件
    private final static String TRAIN = "C:\\Users\\chenyang\\Desktop\\train.txt";
    //生成的测试文件
    private final static String TEST = "C:\\Users\\chenyang\\Desktop\\test.txt";

    public static void main(String[] args) {
        try {
            exportCurpor();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void exportCurpor() throws Exception {
        StopWatch watch = new StopWatch();
        watch.start();
        Path p = Paths.get(PATH);

        List<String> list = new ArrayList<>();

        Files.walkFileTree(p, new FileVisitor<Path>() {

            @Override
            public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) throws IOException {
                System.out.println("遍历文件夹:" + dir);
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                System.out.println("遍历文件:" + file.getFileName());
                Stream<String> s = Files.lines(file);
                List<String> rel = s.map(l -> HanLP.newSegment().enableCustomDictionaryForcing(true)
                        .enablePlaceRecognize(true)
                        .enableOrganizationRecognize(true).seg(l.split("XXXXXX")[4]).stream().map(t -> t.word)
                        .filter(w -> NormalizeUtils.checkStringContainChinese(w)) //必须要有中文
                        .filter(w -> !CoreStopWordDictionary.contains(w))
                        .filter(w -> !w.matches(
                                "^[\\pP|\\pS|\\p{Zs}]+|"
                                        + "[\\-|\\+|\\.|\\,]+$")) //过滤掉数字和标点
                        .collect(Collectors.joining(" "))+"   __label__"+l.split("XXXXXX")[1]
                ).collect(Collectors.toList());
                for (String tool : rel){
                    list.add(tool);
                }
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult visitFileFailed(Path file, IOException exc) throws IOException {
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
                return FileVisitResult.CONTINUE;
            }
        });

        Path target = Paths.get(TRAIN);
        FileOutputStream fout = new FileOutputStream(target.toFile(), false);
        OutputStreamWriter out = new OutputStreamWriter(fout, "UTF-8");
        BufferedWriter bw = new BufferedWriter(out);

        Path test = Paths.get(TEST);
        FileOutputStream fout1 = new FileOutputStream(test.toFile(), false);
        OutputStreamWriter out1 = new OutputStreamWriter(fout1, "UTF-8");
        BufferedWriter bw1 = new BufferedWriter(out1);

        Collections.shuffle(list);

        for (int i=0;i<list.size();i++){
            if (i%10==0 || i%10==3 || i%10==7) {
                bw1.write(list.get(i));
                bw1.newLine();
            }else {
                bw.write(list.get(i));
                bw.newLine();
            }
        }

        bw1.close();
        bw.close();
        out.close();
        fout.close();
        out1.close();
        fout1.close();

        watch.stop();

        System.out.println(watch.toString());
    }
}

github地址:https://github.com/cy576013581/text-classification

展开阅读全文

没有更多推荐了,返回首页