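/**
 * Naive Bayes classifier for discrete, integer-encoded features.
 * Laplace smoothing matters here but is not implemented yet, so small training sets can give misleading results.
 * @author ysh 1208706282
 *
 */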
public class NavieBayes {
Map<Integer,Integer> labelInfo;
Map<String,FeatureInfo> featureInfo;
List<Sample> samples;
static class Sample{
int label;
List<Integer> feature;
}
static class FeatureInfo{
int label;
int featureId;
int featureValue;
int count;
double rate;
}
/**
* 加载数据
* @param path
* @param regex
* @throws Exception
*/
public void loadData(String path,String regex) throws Exception{
samples = new ArrayList<Sample>();
BufferedReader reader = new BufferedReader(new FileReader(path));
String line = null;
String splits[] = null;
Sample sample = null;
while(null != (line=reader.readLine())){
splits = line.split(regex);
sample = new Sample();
sample.feature = new ArrayList<Integer>(splits.length-1);
for(int i=0;i<splits.length-1;i++){
sample.feature.add(new Integer(splits[i]));
}
sample.label = Integer.valueOf(splits[splits.length-1]);
samples.add(sample);
}
reader.close();
}
/**
* 加载验证测试集
* @param path
* @param regex
* @throws Exception
*/
public List<Sample> loadTestData(String path,boolean hasLabel,String regex) throws Exception{
List<Sample> samples = new ArrayList<Sample>();
BufferedReader reader = new BufferedReader(new FileReader(path));
String line = null;
String splits[] = null;
Sample sample = null;
while(null != (line=reader.readLine())){
splits = line.split(regex);
sample = new Sample();
sample.feature = new ArrayList<Integer>(splits.length-1);
for(int i=0;i<splits.length-1;i++){
sample.feature.add(new Integer(splits[i]));
}
if(hasLabel){
sample.label = Integer.valueOf(splits[splits.length-1]);
}
samples.add(sample);
}
reader.close();
return samples;
}
public void laplaceSmooth(){
}
public void train(){
featureInfo = new HashMap<String,FeatureInfo>();
labelInfo = new HashMap<Integer,Integer>();
String key = null;
FeatureInfo info = null;
for(Sample sample:samples) {
if(null == labelInfo.get(sample.label)){
labelInfo.put(sample.label, 1);
}else{
labelInfo.put(sample.label, labelInfo.get(sample.label)+1);
}
for(int i=0;i<sample.feature.size();i++){
key = sample.label+";"+i+";"+sample.feature.get(i);
info = featureInfo.get(key);
if(null == info){
info = new FeatureInfo();
info.count = 1;
info.featureId = i;
info.featureValue = sample.feature.get(i);
info.label = sample.label;
featureInfo.put(key, info);
}else{
info.count += 1;
}
}
}
Iterator<Entry<Integer,Integer>> iter = labelInfo.entrySet().iterator();
Entry<Integer,Integer> entry = null;
while(iter.hasNext()){
entry = iter.next();
System.out.println("label: "+entry.getKey()+" count: "+entry.getValue());
}
Set<String> set = featureInfo.keySet();
for(String str:set){
System.out.println(str+" count:"+featureInfo.get(str).count);
}
}
public int classify(Sample sample){
int label = 0;
double max = -1;
String key = null;
FeatureInfo info = null;
Set<Integer> set = labelInfo.keySet();
for(Integer la:set){
double rate = 1;
for(int i=0;i<sample.feature.size();i++){
key = la.intValue()+";"+i+";"+sample.feature.get(i);
info = featureInfo.get(key);
if(info != null){
rate *= (1.0*info.count/labelInfo.get(la));
}else{
//System.out.println("error");
rate *= (1.0/labelInfo.get(la));
}
}
rate *= (1.0*labelInfo.get(la)/samples.size());
if(rate > max){
max = rate;
label = la.intValue();
}
//System.out.println("label: "+la+" rate:"+rate);
}
return label;
}
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
// TODO Auto-generated method stub
String pathTrain = "F:/uci/data/car/car.train_train";
String pathTest = "F:/uci/data/car/car.train_test";
NavieBayes nb = new NavieBayes();
nb.loadData(pathTrain, ",");
nb.train();
List<Sample> test = nb.loadTestData(pathTest,true,",");
int count = 0;
for(Sample sample:test){
int predict = nb.classify(sample);
System.out.println("label: "+sample.label+" predict: "+predict);
if(predict == sample.label){
count++;
}
}
System.out.println("right rate: "+(count*1.0/test.size()));
}
}
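/**
 * Naive Bayes classifier for discrete, integer-encoded features.
 * Laplace smoothing matters here but is not implemented yet, so small training sets can give misleading results.
 * @author ysh 1208706282
 *
 */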
public class NavieBayes {
Map<Integer,Integer> labelInfo;
Map<String,FeatureInfo> featureInfo;
List<Sample> samples;
static class Sample{
int label;
List<Integer> feature;
}
static class FeatureInfo{
int label;
int featureId;
int featureValue;
int count;
double rate;
}
/**
* 加载数据
* @param path
* @param regex
* @throws Exception
*/
public void loadData(String path,String regex) throws Exception{
samples = new ArrayList<Sample>();
BufferedReader reader = new BufferedReader(new FileReader(path));
String line = null;
String splits[] = null;
Sample sample = null;
while(null != (line=reader.readLine())){
splits = line.split(regex);
sample = new Sample();
sample.feature = new ArrayList<Integer>(splits.length-1);
for(int i=0;i<splits.length-1;i++){
sample.feature.add(new Integer(splits[i]));
}
sample.label = Integer.valueOf(splits[splits.length-1]);
samples.add(sample);
}
reader.close();
}
/**
* 加载验证测试集
* @param path
* @param regex
* @throws Exception
*/
public List<Sample> loadTestData(String path,boolean hasLabel,String regex) throws Exception{
List<Sample> samples = new ArrayList<Sample>();
BufferedReader reader = new BufferedReader(new FileReader(path));
String line = null;
String splits[] = null;
Sample sample = null;
while(null != (line=reader.readLine())){
splits = line.split(regex);
sample = new Sample();
sample.feature = new ArrayList<Integer>(splits.length-1);
for(int i=0;i<splits.length-1;i++){
sample.feature.add(new Integer(splits[i]));
}
if(hasLabel){
sample.label = Integer.valueOf(splits[splits.length-1]);
}
samples.add(sample);
}
reader.close();
return samples;
}
public void laplaceSmooth(){
}
public void train(){
featureInfo = new HashMap<String,FeatureInfo>();
labelInfo = new HashMap<Integer,Integer>();
String key = null;
FeatureInfo info = null;
for(Sample sample:samples) {
if(null == labelInfo.get(sample.label)){
labelInfo.put(sample.label, 1);
}else{
labelInfo.put(sample.label, labelInfo.get(sample.label)+1);
}
for(int i=0;i<sample.feature.size();i++){
key = sample.label+";"+i+";"+sample.feature.get(i);
info = featureInfo.get(key);
if(null == info){
info = new FeatureInfo();
info.count = 1;
info.featureId = i;
info.featureValue = sample.feature.get(i);
info.label = sample.label;
featureInfo.put(key, info);
}else{
info.count += 1;
}
}
}
Iterator<Entry<Integer,Integer>> iter = labelInfo.entrySet().iterator();
Entry<Integer,Integer> entry = null;
while(iter.hasNext()){
entry = iter.next();
System.out.println("label: "+entry.getKey()+" count: "+entry.getValue());
}
Set<String> set = featureInfo.keySet();
for(String str:set){
System.out.println(str+" count:"+featureInfo.get(str).count);
}
}
public int classify(Sample sample){
int label = 0;
double max = -1;
String key = null;
FeatureInfo info = null;
Set<Integer> set = labelInfo.keySet();
for(Integer la:set){
double rate = 1;
for(int i=0;i<sample.feature.size();i++){
key = la.intValue()+";"+i+";"+sample.feature.get(i);
info = featureInfo.get(key);
if(info != null){
rate *= (1.0*info.count/labelInfo.get(la));
}else{
//System.out.println("error");
rate *= (1.0/labelInfo.get(la));
}
}
rate *= (1.0*labelInfo.get(la)/samples.size());
if(rate > max){
max = rate;
label = la.intValue();
}
//System.out.println("label: "+la+" rate:"+rate);
}
return label;
}
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
// TODO Auto-generated method stub
String pathTrain = "F:/uci/data/car/car.train_train";
String pathTest = "F:/uci/data/car/car.train_test";
NavieBayes nb = new NavieBayes();
nb.loadData(pathTrain, ",");
nb.train();
List<Sample> test = nb.loadTestData(pathTest,true,",");
int count = 0;
for(Sample sample:test){
int predict = nb.classify(sample);
System.out.println("label: "+sample.label+" predict: "+predict);
if(predict == sample.label){
count++;
}
}
System.out.println("right rate: "+(count*1.0/test.size()));
}
}