文本分类的目的是将文本文档分为不同的类,这是NLP中非常重要的分析手段。这里将使用一种技术,它基于一种叫作tf-idf的统计数据,它表示词频-逆文档频率(term frequency—inversedocument frequency)。这个统计工具有助于理解一个单词在一组文档中对某一个文档的重要性。它可以作为特征向量来做文档分类。
实际上就是利用现有数据或者已有模型来预测输入的文档的类别。
分为以下几个基本步骤:
(1)先期要人工进行定义,当然,实际中,结合着机器算法共同完成效果更佳。类别定义,可以用词典映射的方式进行,这些类型我们以新闻组数据集为例。
category_map ={'misc.forsale': 'Sales', 'rec.motorcycles': 'Motorcycles','rec.sport.baseball':'Baseball', 'sci.crypt': 'Cryptography','sci.space': 'Space'}
(2)基于这种定义类型,进行训练数据的加载。
training_data =fetch_20newsgroups(subset='train', categories=category_map.keys(),
shuffle=True, random_state=7)
(3)导入特征提取器,然后进行特征提取。我们直接加载sklearn中CountVectorizer特征提取器。
from sklearn.feature_extraction.textimport CountVectorizer
vectorizer = CountVectorizer()
X_train_termcounts =vectorizer.fit_transform(training_data.data)
(4)利用分类器,这里选择多项式朴素贝叶斯(Multinomial Naive Bayes)分类器进行分类,且利用tfidfTransformer先进行特征向量转换,然后再根据这个向量再进行分类。
from sklearn.naive_bayes importMultinomialNB
from sklearn.feature_extraction.text import TfidfTransformer
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_ termcounts)
利用tf-idf变换器定义对象,并对步骤(3)得到的特征向量进行tf-idf转换,然后利用多项式朴素贝叶斯分类器进行训练分类。
classifier = MultinomialNB().fit(X_train_tfidf, training_data.target)
(5)到这,训练完毕,输入数据可以进行分类预测了。
首先,用词频统计转换输入数据:
X_input_termcounts = vectorizer.transform(input_data)
然后,用tf-idf变换器变换输入数据:
X_input_tfidf = tfidf_transformer.transform(X_input_termcounts)
最后,用训练过的分类器来对输入数据的tfidf向量进行预测,也即就是对输入句子进行输出类型预测:
# 预测输出类型
predicted_categories = classifier.predict(X_input_tfidf)
(6)结果输出:
for sentence, category in zip(input_data, predicted_categories):
print ('\nInput:', sentence,'\nPredicted category:', category_map[training_data\
.target_names[category]])
当然,需要导入包:fromsklearn.datasets import fetch_20newsgroups,这个数据集中包括训练集和测试集,共20个新闻组,感兴趣的读者,可以利用下面的代码读取数据集。
from sklearn.datasets importfetch_20newsgroups
newsgroups_train =fetch_20newsgroups(subset='train')
newsgroups_test =fetch_20newsgroups(subset='test')
print (len(newsgroups_train.data))
print (len(newsgroups_test.data))
news =fetch_20newsgroups(subset='all')
print (len(news.data))
这个功能比较简单,以后有时间,我们将设计一款实用型分类器。大家,加油!