# coding=utf-8
import re
import pandas as pd
import string
import MySQLdb
import jieba
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from sklearn.svm import LinearSVC
# Word segmentation with jieba.
def jieba_tokenizer(x):
    """Segment the text *x* by calling ``jieba.cut`` with ``cut_all=True``.

    The value produced by ``jieba.cut`` is returned as-is; callers in this
    file consume it by joining the tokens with spaces.
    """
    tokens = jieba.cut(x, cut_all=True)
    return tokens
def partition(x):
    """Return *x* unchanged.

    Used as the mapping function for the class-id column, so each label
    passes through as it is.
    """
    unchanged = x
    return unchanged
# Pattern for a markup tag: "<", one or more characters other than ">", then
# ">". Compiled once at import time instead of on every filter_html() call.
_HTML_TAG_RE = re.compile(r'<[^>]+>', re.S)


def filter_html(s):
    """Return *s* with every ``<...>`` tag removed.

    Only the tags themselves are dropped; the text between tags is kept.
    A ``<`` that is never followed by a closing ``>`` is left untouched.

    Args:
        s: The text to clean.

    Returns:
        The text with all tag matches replaced by the empty string.
    """
    return _HTML_TAG_RE.sub('', s)
# Connect to the MySQL database.
conn = MySQLdb.connect(host='localhost', user='root', passwd='', db='article', port=3306, charset="utf8")
try:
    # The cursor is only used to set the session character set; every query
    # below goes through pd.read_sql_query(sql, conn), so release it right away.
    cursor = conn.cursor()
    try:
        cursor.execute("SET NAMES utf8")
    finally:
        cursor.close()

    # Training sample: articles with id > 100, read as five pages of 1000 rows.
    pages = []
    for i in range(0, 5):
        sql = "SELECT a.id,a.title,a.classid,b.artcontent FROM article a,article_txt b WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id>100 ORDER BY a.id ASC LIMIT " + str(i * 1000) + ",1000"
        pages.append(pd.read_sql_query(sql, conn))
    # Concatenate once at the end: appending to the DataFrame inside the loop
    # copies all accumulated rows on every iteration, and DataFrame.append is
    # no longer available in recent pandas releases.
    data_ret = pd.concat(pages, ignore_index=True)

    Score = data_ret['classid']
    data_ret['artcontent'] = [filter_html(msg) for msg in data_ret['artcontent']]
    X_train = data_ret['artcontent']
    Y_train = Score.map(partition)

    # Each document handed to CountVectorizer is the article's jieba tokens
    # joined with single spaces.
    corpus = [' '.join(jieba_tokenizer(txt)) for txt in X_train]
    count_vect = CountVectorizer()
    X_train_counts = count_vect.fit_transform(corpus)
    tfidf_transformer = TfidfTransformer()
    X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
    clf = LinearSVC().fit(X_train_tfidf, Y_train)
    # clf could be persisted at this point.

    # Test data: predict the class of the articles with id < 50.
    test_txt_data = pd.read_sql_query("SELECT a.id,a.title,a.classid,b.artcontent FROM article a,article_txt b WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id<50 ORDER BY a.id ASC", conn)
    X_test = [filter_html(msg) for msg in test_txt_data['artcontent']]
    test_set = [' '.join(jieba_tokenizer(text)) for text in X_test]
    # Reuse the vocabulary and IDF weights fitted on the training data.
    X_new_counts = count_vect.transform(test_set)
    X_test_tfidf = tfidf_transformer.transform(X_new_counts)
    result = clf.predict(X_test_tfidf)
    for article_id, predicted in zip(test_txt_data['id'], result):
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print("ID:" + str(article_id) + " -> classid:" + str(predicted))
finally:
    # Always release the connection, even if a query or the training fails.
    conn.close()
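# Source article: "Learning sklearn -- reading a MySQL data source to train on samples and predict text classification"
# (Page footer from the source site: latest recommended article published 2022-12-18 22:05:40)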