FastText实战天池新闻文本分类比赛

FastText的精髓在于将整篇文档的词及n-gram向量叠加平均得到文档向量,然后使用文档向量做softmax多分类。

上面是业界大佬们对FastText模型的高度总结,看起来是不是特别简单?然而,模型的输入到底是什么?为什么要引入n-gram向量?文档的词以及n-gram向量是怎么叠加的?为什么它训练速度快并且性能很好?这些问题一直困扰着我,今天就来深扒一下。

 一、理论基础

FastText是facebook开源的一个快速文本分类器,在提供简单高效的文本分类和表征学习方法的同时,性能比肩深度学习并且训练速度非常快,往往可以作为文本分类场景下的baseline。其模型非常简单,和word2vec的cbow模型很相似,不同点在于cbow预测的是中心词,而fasttext预测的是文本标签。

上图是论文中模型的架构,其模型输入 = 句子本身 + n-gram额外特征。举个例子:我喜欢她。我们对这句话分词后得到:我,喜欢,她。其对应的bi-gram特征为:我喜欢,喜欢他。那么模型的输入变为:我,喜欢,她, 我喜欢,喜欢她。这样做一方面引入了更多字符特征,另一方面解决了词顺序的问题,毕竟我喜欢她和她喜欢我还不是同一个意思。Tri-gram以此类推。

由于每个句子的长度不同,为了便于建模,需要把每个句子填充到相同的长度。如果只使用常见的10000的词(数据预处理后),我们需要把它映射到2-10002之间的整数索引,其中1是留给未登录词,0是用来填充长度。此外,如果数据量很大且句子较长,会引起n-gram数据组合爆炸的问题,原论文中通过采用hash映射到1到K之间来解决这个问题(具体见下面的模型优化章节),同时为了避免和前面的汉字索引出现冲突,哈希映射值一般会加上最大长度的值,即映射到10003-10003+k之间。所以模型的输入长度 = 句子填充长度 + 哈希映射值 = 10003 + k。

数据转换索引后,模型会经过Embedding层,将索引映射为稠密向量,那么模型是如何求平均的呢?这里参考Keras官方实现的fasttext的文本分类文档:

model = Sequential()

# 我们从有效的嵌入层开始,该层将 vocab 索引映射到 embedding_dims 维度
model.add(Embedding(max_features,
                    embedding_dims,
                    input_length=maxlen))

# 我们添加了 GlobalAveragePooling1D,它将对文档中所有单词执行平均嵌入
model.add(GlobalAveragePooling1D())

# 我们投影到单个单位输出层上,并用 sigmoid 压扁它:
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

这里采用了GlobalAveragePooling1D()来对所有文档执行嵌入,它会把所有词的Embedding的向量求平均得到一个向量。即把我,喜欢,她,我喜欢,喜欢她这5个词的Embedding求和再除以5,得到均值向量,然后跟上输出层。

在模型优化上,作者主要采用了两种加速训练的方法:

1、层次softmax: 这个并不是新颖的技术,主要是用来解决文本类别比较多时,使用softmax计算会使计算复杂度很高,大概是O(kh),其中k是文本类别树,h是embedding维度。而层次softmax采用哈夫曼树的方式来把复杂度降低到O(hlog2k),大大加快了训练速度。具体细节不铺开讨论。

2、其次,为了节省内存和n-gram组合爆炸的问题,fasttext把n-gram数据构成一个词典,并通过哈希函数映成整数(索引)至1到K,理论上哈希到同一个位置的不同n-gram应该共享索引和Embedding(待验证)

二、代码实战

1、安装fasttext:

git clone https://github.com/facebookresearch/fastText.git
cd fastText
pip install .

如果安装不上,可以用gensim包中的fasttext。

2、数据简介:

以天池新闻文本分类数据集为数据集,其中训练集有20万条样本,测试集有AB两个版本各5万条样本,为避免选手自行打标,对数据按字符进行了匿名处理,可以理解为把汉字转化成了索引,样本如下图所示:

数据匿名处理其实帮我们省略了很多数据预处理的工作,比如说去除标点符号、去除停用词等等,但是可能也难以达到特别好的精度。

数据统计总共有14个类别,是个典型的文本多分类问题,评估指标为f1-score 。

3、模型参数:

input             # training file path (required)
lr                # learning rate [0.1]
dim               # size of word vectors [100]
ws                # size of the context window [5]
epoch             # number of epochs [5]
minCount          # minimal number of word occurences [1]
minCountLabel     # minimal number of label occurences [1]
minn              # min length of char ngram [0]
maxn              # max length of char ngram [0]
neg               # number of negatives sampled [5]
wordNgrams        # max length of word ngram [1]
loss              # loss function {ns, hs, softmax, ova} [softmax]
bucket            # number of buckets [2000000]
thread            # number of threads [number of cpus]
lrUpdateRate      # change the rate of updates for the learning rate [100]
t                 # sampling threshold [0.0001]
label             # label prefix ['__label__']
verbose           # verbose [2]
pretrainedVectors

4、完整代码

import fasttext
import pandas as pd
from sklearn.utils import shuffle


class DataProcess(object):

    def load_data(self):
        df_train = pd.read_csv('train_set.csv', sep='\t')

        # 对类别加上 "__label__"前缀
        df_train['label_ft'] = '__label__' + df_train['label'].astype(str)

        df_train[['text', 'label_ft']].iloc[:195000].to_csv('train.csv', index=None, header=None, sep='\t')

        return df_train

    def split_data(self, df_train):
        # 打乱数据集
        df_train = shuffle(df_train)

        # 训练集
        train_data = df_train[['text', 'label_ft']].iloc[:195000]
        train_data.to_csv('train.csv', index=None, header=None, sep='\t')

        # 挑选5000条数据作为验证集
        validate_data = df_train[['text', 'label_ft']].iloc[-5000:]
        validate_data.to_csv('validate.csv', index=None, header=None, sep='\t')


class FastTextModel(object):

    def __init__(self, ):
        pass

    def train(self):
        model = fasttext.train_supervised(input='train.csv',
                                          label_prefix="__label__",
                                          epoch=30,
                                          dim=32,
                                          lr=0.1,
                                          loss='softmax',
                                          word_ngrams=3,
                                          min_count=2,
                                          bucket=1000000)

        return model

    def save_model(self, model):
        model.save_model("fasttext.bin")

    def load_model(self):
        model = fasttext.load_model("fasttext.bin")

        return model

    # 预测验证集结果
    def test(self):
        model = self.load_model()
        score = model.test("validate.csv")
        precision = score[1]
        recall = score[2]
        f1_score = round(2 * (precision * recall) / (precision + recall), 2)

        print("验证集评测结果:Precision:{}, Recall:{}, F1-score:{}".format(precision, recall, f1_score))

    # 预测5万条测试集A的结果,或者测试集B的结果提交
    def predict_testA(self):
        df_testA = pd.read_csv("test_a.csv")
        test_data = df_testA["text"].values.tolist()

        model = self.load_model()
        res = model.predict(test_data)

        predict_res = [y_[0].replace("__label__", "") for y_ in res[0]]
        print(predict_res)
        predict_label = pd.Series(predict_res, name="label")
        predict_label.to_csv("predict_label.csv", index=False)


if __name__ == '__main__':
    data_process = DataProcess()
    fasttext_model = FastTextModel()
    df_train = data_process.load_data()
    data_process.split_data(df_train)
    model = fasttext_model.train()
    fasttext_model.save_model(model)
    fasttext_model.test()
    fasttext_model.predict_testA()

模型没怎么调参,F1-score达到了0.95:

验证集评测结果:Precision:0.9466, Recall:0.9466, F1-score:0.95

 由于数据量比较大,可以移步公众号:一路向AI,回复文本分类获取。

  • 3
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值