Text Classification with the EM Algorithm

100 articles belong to three classes: Business, Politics, and Sport. We cluster them with the EM algorithm, then test on an additional 30 articles.

Implemented in Python:

from stemming.porter2 import stem
import re
import os
import numpy as np

# Rename all files in a directory to 0.txt, 1.txt, ..., n.txt.
def rename_file(path, start=0):
	for fname in os.listdir(path):
		if os.path.isfile(os.path.join(path, fname)):
			os.rename(os.path.join(path, fname), os.path.join(path, str(start) + '.txt'))
			start += 1

# Drop adverbs (a crude "-ly" heuristic), numbers, and other low-value tokens; stem the rest.
def filter_word(word):
	word = word.strip("@[].?'\" ").lower()
	if not word.isalpha() or len(word) < 3 or word.endswith("ly"):
		return ''
	return stem(word)
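
For example (a quick sanity check; these calls are not in the original post):

	print(filter_word("Markets"))   # "market" after Porter2 stemming
	print(filter_word("quickly"))   # '' (caught by the "-ly" adverb heuristic)
	print(filter_word("2016"))      # '' (not alphabetic)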

# Build a {word: article_count} dictionary over all articles;
# "coverage" means how many articles contain the word.
def get_all_words_with_coverage(num, path):
	os.chdir(path)
	word_coverage = {}
	for i in range(num):
		with open(str(i) + ".txt", "r", encoding="utf-8") as textfile:
			words_in_article = re.split(r'\s+', textfile.read())
		# Filter the words, then drop empty strings and duplicates.
		words_in_article = set(x for x in [filter_word(w) for w in words_in_article] if x)
		# Count each surviving word once per article.
		for w in words_in_article:
			word_coverage[w] = word_coverage.get(w, 0) + 1
	return word_coverage
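
For instance, the coverage counts can be inspected by sorting them (a minimal sketch, not in the original post; "_path_to_train" is the same placeholder used below):

	word_coverage = get_all_words_with_coverage(100, "_path_to_train")
	# Very frequent and very rare stems discriminate poorly between classes,
	# so look for mid-coverage, class-specific terms.
	for word, count in sorted(word_coverage.items(), key=lambda kv: -kv[1])[:50]:
		print(word, count)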

# After investigating the output of get_all_words_with_coverage() in this way,
# the keywords finally selected are:
keywords=["business","sale","percent","quarter","company","money","investment","manager","advertise","market","politics","clinton","donald","nominee","democratic","president","republican","national","official","candidate","sport","game","play","team","championship","league","season","win","nba","loss"]
#keywords = ["company","market","business","sale","president","official","clinton","donald","game","team","play","sport"]  # earlier, smaller set: 15/100, 7/30
def get_article_vector(num, path):
	os.chdir(path)
	obs = []
	for i in range(num):
		with open(str(i) + ".txt", "r", encoding="utf-8") as textfile:
			words_in_article = re.split(r'\s+', textfile.read())
		# Filter the words, then drop empty strings and duplicates.
		words_in_article = set(x for x in [filter_word(w) for w in words_in_article] if x)
		# Binary feature vector: 1 if the keyword occurs in the article, else 0.
		article_vector = [1 if word in words_in_article else 0 for word in keywords]
		obs.append(article_vector)
	return obs
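
For instance (assuming 100 training files named 0.txt..99.txt under the placeholder path used later):

	obs = get_article_vector(100, "_path_to_train")
	print(len(obs), len(obs[0]))   # 100 articles x 30 keyword indicators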

# Bernoulli likelihood of one binary article vector under component parameters `probs`.
def get_likelihood(obs, probs):
	likelihood = 1.0
	for x in range(len(obs)):
		if obs[x]:
			likelihood *= probs[x]
		else:
			likelihood *= (1 - probs[x])
	return likelihood
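
A quick worked example (values made up for illustration):

	# obs = [1, 0] with probs = [0.9, 0.2] gives 0.9 * (1 - 0.2) = 0.72
	print(get_likelihood([1, 0], [0.9, 0.2]))   # 0.72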

def doTraining(observations, k=3):
	if not isinstance(observations, np.ndarray):
		observations = np.array(observations)
	n, p = observations.shape
	# Random initial Bernoulli parameters, one row per mixture component.
	pBPS = np.random.rand(k, p)
	z = np.zeros((n, k))      # responsibilities
	mixture = [1.0 / k] * k   # uniform initial weights (the original hard-coded [0.3, 0.3, 0.4] for k = 3)
	delta = 1e-8
	improvement = float("inf")
	times = 1                 # iteration counter
	while improvement > delta and times < 100:
		pBPS_old = np.copy(pBPS)
		# E-step: unnormalized responsibility of each component for each article.
		for i in range(n):
			for j in range(k):
				z[i, j] = mixture[j] * get_likelihood(observations[i], pBPS[j])
		# Normalize each row into posterior probabilities.
		for i in range(n):
			z[i, :] = z[i, :] / sum(z[i, :])
		# M-step: re-estimate the mixture weights and Bernoulli parameters.
		for j in range(k):
			mixture[j] = sum(z[:, j])
			for i in range(p):
				pBPS[j][i] = sum(observations[:, i] * z[:, j]) / mixture[j]
		mixture = [x / n for x in mixture]
		improvement = abs(pBPS - pBPS_old).max()
		times = times + 1
	print("iterations:", times)
	print("mixture:", mixture)
	print("Training result:")
	print(np.around(z, 2))
	return pBPS
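
For reference, the loop above implements the standard EM updates for a mixture of multivariate Bernoullis, with x_i the i-th article vector and L the per-component likelihood computed by get_likelihood:

	E-step:  z[i][j] = mixture[j] * L(x_i | pBPS[j]) / sum_k( mixture[k] * L(x_i | pBPS[k]) )
	M-step:  mixture[j] = (1/n) * sum_i z[i][j]
	         pBPS[j][w] = sum_i( z[i][j] * x[i][w] ) / sum_i z[i][j]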

def doTesting(test_obs, theta):
	n = len(test_obs)
	k = len(theta)
	z = np.zeros((n, k))
	# Score each test article under each component. Note the learned mixture
	# weights are not applied here, i.e. a uniform prior over components is assumed.
	for i in range(n):
		for j in range(k):
			z[i, j] = get_likelihood(test_obs[i], theta[j])
	for i in range(n):
		z[i, :] = z[i, :] / sum(z[i, :])
	print("Testing result:")
	print(np.around(z, 2))
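
If a hard label per article is wanted instead of the printed posterior matrix, a small helper on top of get_likelihood could look like this (a sketch; classify is not part of the original code):

	def classify(article_vector, theta):
		# Pick the component whose parameters best explain the article.
		scores = [get_likelihood(article_vector, theta[j]) for j in range(len(theta))]
		return int(np.argmax(scores))

Since EM recovers the components in arbitrary order, mapping component indices to the Business/Politics/Sport labels still requires inspecting a few known articles.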

observations = get_article_vector(100, "_path_to_train")
theta = doTraining(observations)
test_data = get_article_vector(30, "_path_to_test")
doTesting(test_data, theta)

Reposted from: https://my.oschina.net/fazheng/blog/681248
