京东评论情感分类器(基于bag-of-words模型)
近期本来在研究paraVector模型,想拿bag-of-words模型来做对照。
数据集是京东的评论,经过人工挑选,选出一批正面和负面的评论。
实验的数据量不大,340条正面,314条负面。我一般拿200条正面和200条负面做训练,剩下的做测试。
做着做着,领悟了一些机器学习的道理:同一个模型在不同的数据集上,效果是不同的。
对于特定的数据集,随便拿来一套模型可能并不适用。
对于这些评论,我感觉就是bag-of-words模型靠谱点。
因为这些评论的特点是语句简短、关键词重要。
paraVector模型感觉比较擅长长文本的分析,注重上下文。
事实上我还结合了两个模型来做一个新的模型,准确率有所提高,但提高不大,可能是我的数据量太少了。
整理了一下思路,做了个评论情感分类的demo。
特征抽取是bag-of-words模型。
分类器是自己设计的一个模型,结合了knn和kmeans的思想:依据正、负样本的训练集分别求出两个聚类中心(各自特征向量的均值),每次有新样本进来,就分别计算它到两个中心的距离,离哪个中心近就判为哪一类。
下面是demo的代码:
import java.util.Scanner;
/**
 * Interactive demo: trains the bag-of-words two-centroid classifier on the JD comment data set,
 * then classifies each non-blank line typed on standard input as "good" or "bad".
 *
 * <p>Usage: {@code java BowInterTest [dataDir]}. {@code dataDir} must contain the files
 * {@code all} (good + bad, used to build the vocabulary), {@code good} and {@code bad};
 * it defaults to the author's original location.
 */
public class BowInterTest {
    /** Data directory used when no command-line argument is given. */
    private static final String DEFAULT_DATA_DIR = "/media/linger/G/sources/comment/test";

    public static void main(String[] args) throws Throwable
    {
        String dataDir = args.length > 0 ? args[0] : DEFAULT_DATA_DIR;
        BowModel bm = new BowModel(dataDir + "/all"); // all = good + bad
        double[][] good = bm.generateFeature(dataDir + "/good", 340);
        double[][] bad = bm.generateFeature(dataDir + "/bad", 314);
        bm.train(good, 0, 200, bad, 0, 200); // training rows
        // Uncomment to evaluate on the held-out rows instead of running interactively:
        //bm.test(good, 200, 340, bad, 200, 314);

        // Interactive mode. hasNextLine() (not hasNext()) keeps the check and the read in step,
        // so a blank line is not mistaken for a document.
        Scanner sc = new Scanner(System.in);
        try
        {
            while (sc.hasNextLine())
            {
                String doc = sc.nextLine();
                if (doc.trim().length() == 0)
                    continue; // nothing to classify
                double[] fea = bm.docFea(doc);
                Norm.arrayNorm2(fea); // training vectors are L2-normalised too
                double re = bm.predict(fea);
                // predict() returns dist(good centroid) - dist(bad centroid): negative means closer to "good".
                if (re < 0)
                {
                    System.out.println("good:" + re);
                }
                else
                {
                    System.out.println("bad:" + re);
                }
            }
        }
        finally
        {
            sc.close();
        }
    }
}
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.StringTokenizer;
/**
 * Bag-of-words front end for {@link KnnCoreModel}: builds a vocabulary from a corpus file and
 * turns space-separated documents into term-count vectors over that vocabulary.
 */
public class BowModel extends KnnCoreModel
{
    Dict dict;
    DocFeatureFactory dff;

    /**
     * Loads the vocabulary from {@code path} (a UTF-8 file of space-separated tokens).
     *
     * @throws IOException if the file cannot be read
     */
    public BowModel(String path) throws IOException
    {
        dict = new Dict();
        dict.loadFromLocalFile(path);
        dff = new DocFeatureFactory(dict.getWord2Index());
    }

    /** Returns the raw (un-normalised) term-count vector of one space-separated document. */
    public double[] docFea(String doc)
    {
        return dff.getFeature(doc);
    }

    /**
     * Reads the first {@code docNum} lines of a UTF-8 file, one document per line, and returns
     * their term-count vectors (row i = line i). Lines after the first {@code docNum} are ignored.
     *
     * @param docsFile path of the document file
     * @param docNum   number of documents to read
     * @return a {@code docNum x vocabularySize} feature table with no null rows
     * @throws IOException if the file cannot be read or holds fewer than {@code docNum} lines
     */
    public double[][] generateFeature(String docsFile,int docNum) throws IOException
    {
        double[][] featureTable = new double[docNum][];
        int docIndex = 0;
        BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(docsFile)), "utf-8"));
        try
        {
            String line;
            // Bounding by docNum prevents ArrayIndexOutOfBoundsException on longer files.
            while (docIndex < docNum && (line = br.readLine()) != null)
            {
                featureTable[docIndex++] = dff.getFeature(line);
            }
        }
        finally
        {
            br.close(); // released even if tokenising/reading throws
        }
        if (docIndex < docNum)
        {
            // Null rows would only fail later (NPE in Norm.tableNorm2), so fail here with context.
            throw new IOException(docsFile + ": expected " + docNum
                    + " documents but found only " + docIndex);
        }
        return featureTable;
    }
}
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Map.Entry;
/**
 * Vocabulary built from a corpus file: assigns each distinct space-separated token a dense
 * index (in order of first appearance) and counts its occurrences.
 */
public class Dict
{
    // term -> feature index, assigned 0, 1, 2, ... by first appearance; null until loadFromLocalFile() runs.
    HashMap<String,Integer> word2Index =null;
    // term -> number of occurrences in the loaded file; null until loadFromLocalFile() runs.
    Hashtable<String,Integer> word2Count = null;

    /**
     * (Re)builds the vocabulary from a UTF-8 text file whose tokens are separated by spaces.
     * Any previously loaded vocabulary is discarded.
     *
     * @throws IOException if the file cannot be opened or read
     */
    void loadFromLocalFile(String path) throws IOException
    {
        word2Index = new HashMap<String,Integer>();
        word2Count = new Hashtable<String,Integer>();
        BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try
        {
            String line;
            while ((line = br.readLine()) != null)
            {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");
                while (tokenizer.hasMoreTokens())
                {
                    String term = tokenizer.nextToken();
                    Integer count = word2Count.get(term);
                    if (count == null)
                    {
                        word2Count.put(term, 1);
                        // Each new term is added exactly once, so the current size is the next free index.
                        word2Index.put(term, word2Index.size());
                    }
                    else
                    {
                        word2Count.put(term, count + 1);
                    }
                }
            }
        }
        finally
        {
            br.close(); // previously leaked when readLine() threw
        }
    }

    /** Returns the term -> index map, or null if no file has been loaded yet. */
    public HashMap<String,Integer> getWord2Index()
    {
        return word2Index;
    }

    /** Prints "term,count" for every term occurring more than 30 times. */
    public void print()
    {
        print(30);
    }

    /** Prints "term,count" for every term occurring more than {@code minCount} times (unspecified order). */
    public void print(int minCount)
    {
        for (Entry<String,Integer> item : word2Count.entrySet())
        {
            if (item.getValue() > minCount)
                System.out.printf("%s,%d\n", item.getKey(), item.getValue());
        }
    }

    /** Loads the corpus given as args[0] (or the author's default path) and prints frequent terms. */
    public static void main(String[] args) throws IOException
    {
        String path = args.length > 0 ? args[0] : "/media/linger/G/sources/comment/test/all";
        Dict dict = new Dict();
        dict.loadFromLocalFile(path);
        dict.print();
    }
}
import java.util.HashMap;
import java.util.StringTokenizer;
/**
 * Converts a space-separated document into a term-count vector over a fixed vocabulary.
 */
public class DocFeatureFactory
{
    // Vocabulary: term -> position in the feature vector.
    HashMap<String,Integer> word2Index;
    // Vector returned by the most recent getFeature() call (each call allocates a fresh array).
    double[] feature;
    // Feature dimensionality: vocabulary size captured at construction time.
    int dim;

    /** Creates a factory whose vectors have {@code w2i.size()} dimensions. */
    public DocFeatureFactory(HashMap<String,Integer> w2i)
    {
        word2Index = w2i;
        dim = w2i.size();
    }

    /**
     * Counts how often each vocabulary term occurs in {@code doc}; tokens not in the
     * vocabulary are ignored. The result is not normalised.
     */
    double[] getFeature(String doc)
    {
        double[] counts = new double[dim];
        StringTokenizer tokenizer = new StringTokenizer(doc, " ");
        while (tokenizer.hasMoreTokens())
        {
            Integer index = word2Index.get(tokenizer.nextToken());
            if (index != null)
                counts[index]++;
        }
        feature = counts;
        return counts;
    }

    /** No standalone behaviour; kept so the class remains runnable as before. */
    public static void main(String[] args)
    {
    }
}
/**
 * Two-centroid (nearest-mean) classifier: the prototype of each class is the mean of its
 * L2-normalised training vectors, and a sample belongs to the class whose prototype is nearer
 * in squared Euclidean distance.
 */
public class KnnCoreModel
{
    // Mean of the normalised "good" training rows; null until train() is called.
    double[] good_standard ;
    // Mean of the normalised "bad" training rows; null until train() is called.
    double[] bad_standard ;

    /**
     * Computes the two class centroids.
     *
     * <p>Both tables are L2-normalised <em>in place</em> (every row, not only the training rows).
     * Then rows [train_good_start, train_good_end) of {@code good} and rows
     * [train_bad_start, train_bad_end) of {@code bad} are averaged.
     *
     * @throws IllegalArgumentException if either range is empty or outside its table
     */
    public void train(double[][] good,int train_good_start,int train_good_end,
                      double[][] bad,int train_bad_start,int train_bad_end)
    {
        checkRange("good", good, train_good_start, train_good_end);
        checkRange("bad", bad, train_bad_start, train_bad_end);
        // L2-normalise so a vector's magnitude no longer depends on document length.
        Norm.tableNorm2(good);
        Norm.tableNorm2(bad);
        good_standard = centroid(good, train_good_start, train_good_end);
        bad_standard = centroid(bad, train_bad_start, train_bad_end);
    }

    /**
     * Evaluates the trained model on the given row ranges, printing "+:" (correct) or "-:" (wrong)
     * with the distance difference for each sample, followed by the error and accuracy rates.
     * Both tables are L2-normalised in place first. A difference of exactly 0 counts as a "good"
     * prediction here.
     *
     * @throws IllegalStateException if {@link #train} has not been called
     */
    public void test(double[][] good,int test_good_start,int test_good_end,
                     double[][] bad,int test_bad_start,int test_bad_end)
    {
        requireTrained();
        Norm.tableNorm2(good);
        Norm.tableNorm2(bad);
        int error = 0;
        for (int i = test_good_start; i < test_good_end; i++)
        {
            double dis = predict(good[i]);
            if (dis > 0)
            {
                error++; // a good sample judged closer to the bad centroid
                System.out.println("-:" + dis);
            }
            else
            {
                System.out.println("+:" + dis);
            }
        }
        for (int i = test_bad_start; i < test_bad_end; i++)
        {
            double dis = predict(bad[i]);
            if (dis > 0)
            {
                System.out.println("+:" + dis);
            }
            else
            {
                error++; // a bad sample judged at least as close to the good centroid
                System.out.println("-:" + dis);
            }
        }
        int count = (test_good_end - test_good_start + test_bad_end - test_bad_start);
        System.out.println("\nerror:" + error + ",total:" + count);
        System.out.println("error rate:" + (double) error / count);
        System.out.println("acc rate:" + (double) (count - error) / count);
    }

    /**
     * Returns dist(fea, good centroid) - dist(fea, bad centroid): negative means {@code fea}
     * is nearer the "good" centroid. {@code fea} is expected to be L2-normalised like the training data.
     *
     * @throws IllegalStateException    if {@link #train} has not been called
     * @throws IllegalArgumentException if {@code fea} has a different dimension than the centroids
     */
    public double predict(double[] fea)
    {
        requireTrained();
        double good_dis = distance(fea, good_standard);
        double bad_dis = distance(fea, bad_standard);
        return good_dis - bad_dis;
    }

    // Squared Euclidean distance; sqrt is omitted because it is monotonic and does not change comparisons.
    private double distance(double[] src,double[] dst)
    {
        if (src.length != dst.length)
        {
            // Previously printed a message and returned 0, silently producing a bogus classification.
            throw new IllegalArgumentException("size not right! src=" + src.length + ", dst=" + dst.length);
        }
        double sum = 0;
        for (int i = 0; i < src.length; i++)
        {
            double d = dst[i] - src[i];
            sum += d * d;
        }
        return sum;
    }

    // Sum of distances from src to rows [start, end) of trainSet; an alternative (kNN-style) score, currently unused.
    private double allDistance(double[]src,double[][] trainSet,int start,int end)
    {
        double sum = 0;
        for (int i = start; i < end && i < trainSet.length; i++)
        {
            sum += distance(src, trainSet[i]);
        }
        return sum;
    }

    // Mean of rows [start, end) of table; the result has table[0].length dimensions.
    private static double[] centroid(double[][] table, int start, int end)
    {
        double[] mean = new double[table[0].length];
        for (int i = start; i < end; i++)
        {
            for (int j = 0; j < table[i].length; j++)
            {
                mean[j] += table[i][j];
            }
        }
        int n = end - start;
        for (int j = 0; j < mean.length; j++)
        {
            mean[j] /= n;
        }
        return mean;
    }

    // Rejects empty or out-of-bounds training ranges, which would otherwise yield NaN centroids or AIOOBE.
    private static void checkRange(String label, double[][] table, int start, int end)
    {
        if (start < 0 || start >= end || end > table.length)
        {
            throw new IllegalArgumentException(label + " training range [" + start + ", " + end
                    + ") is invalid for " + table.length + " samples");
        }
    }

    // Fails fast instead of an NPE when the model is used before train().
    private void requireTrained()
    {
        if (good_standard == null || bad_standard == null)
        {
            throw new IllegalStateException("model has not been trained; call train() first");
        }
    }
}
public class Norm {
public static void arrayNorm2(double[] array)
{
double sum;
sum=0;
for(int j=0;j<array.length;j++)
{
sum +=array[j]*array[j];
}
if(sum == 0) return;
sum = Math.sqrt(sum);
for(int j=0;j<array.length;j++)
{
array[j]/=sum;
}
}
public static void tableNorm2(double[][] table)
{
double sum;
for(int i=0;i<table.length;i++)
{
sum=0;
for(int j=0;j<table[i].length;j++)
{
sum +=table[i][j]*table[i][j];
}
if(sum == 0) continue;
sum = Math.sqrt(sum);
for(int j=0;j<table[i].length;j++)
{
table[i][j]/=sum;
}
}
}
}
数据集下载:http://download.csdn.net/detail/linger2012liu/7758939
本文作者:linger
本文链接:http://blog.csdn.net/lingerlanlan/article/details/38418277