package cjbayesclassfier;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import edu.umd.cloud9.io.pair.PairOfStrings;
/***
* Step 1: split each line of the input file to produce the class counts and the (attribute, class) pair counts
* @author chenjie
*/
public class CJBayesClassfier_Step1 extends Configured implements Tool {
/***
* Mapper:
* Input: weather.txt
* A sample line: Sunny,Hot,High,Weak,No
* Output:
* key          value
* (Sunny,No)   1
* (Hot,No)     1
* (High,No)    1
* (Weak,No)    1
* (CLASS,No)   1
* @author chenjie
*/
public static class CJBayesClassfierMapper extends Mapper<LongWritable, Text, PairOfStrings, LongWritable>
{
PairOfStrings outputKey = new PairOfStrings();
LongWritable outputValue = new LongWritable(1);
@Override
protected void map(
LongWritable key,
Text value,
Context context)
throws IOException, InterruptedException {
String tokens[] = value.toString().split(",");
if(tokens == null || tokens.length < 2)
return;
String classfier = tokens[tokens.length-1];
for(int i = 0; i < tokens.length; i++)
{
if(i < tokens.length-1)
outputKey.set(tokens[i], classfier);
else
outputKey.set("CLASS", classfier);
context.write(outputKey, outputValue);
}
}
}
@Deprecated
public static class CJBayesClassfierReducer extends Reducer<PairOfStrings, LongWritable, PairOfStrings, LongWritable>
{
@Override
protected void reduce(
PairOfStrings key,
Iterable<LongWritable> values,
Context context)
throws IOException, InterruptedException {
Long sum = 0L;
for(LongWritable time : values)
{
sum += time.get();
}
context.write(key, new LongWritable(sum));
}
}
public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, LongWritable, PairOfStrings, Text>
{
/**
* Writes results to multiple named output files
* */
private MultipleOutputs<PairOfStrings, Text> mos;
@Override
protected void setup(Context context)
throws IOException, InterruptedException {
mos = new MultipleOutputs<PairOfStrings, Text>(context); // initialize mos
}
/***
* Sum the values that share the same key
*/
@Override
protected void reduce(
PairOfStrings key,
Iterable<LongWritable> values,
Context context)
throws IOException, InterruptedException {
System.out.println("key =" + key );
Long sum = 0L;
for(LongWritable time : values)
{
sum += time.get();
}
String result = key.getLeftElement() + "," + key.getRightElement() + "," + sum;
if(key.getLeftElement().equals("CLASS"))
mos.write("CLASS", NullWritable.get(), new Text(result));
else
mos.write("OTHERS", NullWritable.get(), new Text(result));
}
/***
* Be sure to close the MultipleOutputs here, otherwise nothing is written to the named outputs
*/
@Override
protected void cleanup(
Context context)
throws IOException, InterruptedException {
mos.close(); // release resources
}
}
public static void main(String[] args) throws Exception
{
args = new String[2];
args[0] = "/media/chenjie/0009418200012FF3/ubuntu/weather.txt";
args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes";
int jobStatus = submitJob(args);
System.exit(jobStatus);
}
public static int submitJob(String[] args) throws Exception {
int jobStatus = ToolRunner.run(new CJBayesClassfier_Step1(), args);
return jobStatus;
}
@SuppressWarnings("deprecation")
@Override
public int run(String[] args) throws Exception {
Configuration conf = getConf();
Job job = new Job(conf);
job.setJobName("Bayes");
MultipleOutputs.addNamedOutput(job, "CLASS", TextOutputFormat.class, Text.class, Text.class);
MultipleOutputs.addNamedOutput(job, "OTHERS", TextOutputFormat.class, Text.class, Text.class);
job.setInputFormatClass(TextInputFormat.class);
job.setOutputFormatClass(TextOutputFormat.class);
job.setOutputKeyClass(PairOfStrings.class);
job.setOutputValueClass(LongWritable.class);
job.setMapperClass(CJBayesClassfierMapper.class);
job.setReducerClass(CJBayesClassfierReducer2.class);
FileInputFormat.setInputPaths(job, new Path(args[0]));
FileOutputFormat.setOutputPath(job, new Path(args[1]));
FileSystem fs = FileSystem.get(conf);
Path outPath = new Path(args[1]);
if(fs.exists(outPath))
{
fs.delete(outPath, true);
}
boolean status = job.waitForCompletion(true);
return status ? 0 : 1;
}
}
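The two named outputs registered in run() ("CLASS" and "OTHERS") surface as the files CLASS-r-00000 and OTHERS-r-00000 in the job's output directory; Step 2 and Step 3 below read those files back through distributed-cache links, so these names must line up with the mos.write() calls in the reducer.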
package cjbayesclassfier;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import edu.umd.cloud9.io.pair.PairOfStrings;
/***
* Step 2: compute the prior and conditional probabilities from the Step 1 counts
* @author chenjie
*
*/
public class CJBayesClassfier_Step2 extends Configured implements Tool {
public static class CJBayesClassfierMapper2 extends Mapper<LongWritable, Text, PairOfStrings, DoubleWritable>
{
PairOfStrings outputKey = new PairOfStrings();
DoubleWritable outputValue = new DoubleWritable(1);
private Map<String,Integer> classMap = new HashMap<String,Integer>();
@Override
protected void setup(Context context) throws IOException, InterruptedException {
FileReader fr = new FileReader("CLASS");
BufferedReader br = new BufferedReader(fr);
String line = null;
while((line = br.readLine()) != null)
{
String tokens[] = line.split(",");
String classfier = tokens[1];
String count = tokens[2];
classMap.put(classfier, Integer.parseInt(count));
}
br.close();
int sum = 0;
for(Map.Entry<String,Integer> entry : classMap.entrySet())
{
sum += entry.getValue();
}
for(Map.Entry<String,Integer> entry : classMap.entrySet())
{
double poss = entry.getValue() * 1.0 / sum;
context.write(new PairOfStrings("CLASS", entry.getKey()), new DoubleWritable(poss));
}
}
@Override
protected void map(
LongWritable key,
Text value,
Context context)
throws IOException, InterruptedException {
String tokens[] = value.toString().split(",");
if(tokens == null || tokens.length < 3)
return;
String X = tokens[0];
String classfier = tokens[1];
Integer count = Integer.valueOf(tokens[2]);
outputKey.set(X, classfier);
Integer classCount = classMap.get(classfier);
outputValue.set(count * 1.0 / classCount);
context.write(outputKey, outputValue);
}
}
public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, DoubleWritable, NullWritable, Text>
{
@Override
protected void reduce(
PairOfStrings key,
Iterable<DoubleWritable> values,
Context context)
throws IOException, InterruptedException {
for(DoubleWritable dw : values)
context.write(NullWritable.get(), new Text(key.getLeftElement() + "," + key.getRightElement() + "," + dw));
}
}
public static void main(String[] args) throws Exception
{
args = new String[2];
args[0] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes/OTHERS-r-00000";
args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes2";
int jobStatus = submitJob(args);
System.exit(jobStatus);
}
public static int submitJob(String[] args) throws Exception {
int jobStatus = ToolRunner.run(new CJBayesClassfier_Step2(), args);
return jobStatus;
}
@SuppressWarnings("deprecation")
@Override
public int run(String[] args) throws Exception {
Configuration conf = getConf();
Job job = new Job(conf);
job.setJobName("Bayes");
job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000" + "#CLASS"));
job.setInputFormatClass(TextInputFormat.class);
job.setOutputFormatClass(TextOutputFormat.class);
job.setOutputKeyClass(PairOfStrings.class);
job.setOutputValueClass(DoubleWritable.class);
job.setMapperClass(CJBayesClassfierMapper2.class);
job.setReducerClass(CJBayesClassfierReducer2.class);
FileInputFormat.setInputPaths(job, new Path(args[0]));
FileOutputFormat.setOutputPath(job, new Path(args[1]));
FileSystem fs = FileSystem.get(conf);
Path outPath = new Path(args[1]);
if(fs.exists(outPath))
{
fs.delete(outPath, true);
}
boolean status = job.waitForCompletion(true);
return status ? 0 : 1;
}
}
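Step 2 turns the Step 1 counts into maximum-likelihood estimates: setup() emits the class priors from the cached CLASS file, and map() divides each (attribute, class) count by the count of its class. Checking two values against the sample run shown further below:

P(C) = count(C) / N, e.g. P(No) = 5 / 14 = 0.35714285714285715
P(x|C) = count(x,C) / count(C), e.g. P(High|No) = 4 / 5 = 0.8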
package cjbayesclassfier;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import edu.umd.cloud9.io.pair.PairOfStrings;
/***
* Step 3: Bayes inference using the probabilities computed in the previous step
* @author chenjie
*
*/
public class CJBayesClassfier_Step3 extends Configured implements Tool {
public static class CJBayesClassfierMapper3 extends Mapper<LongWritable, Text, Text, LongWritable>
{
LongWritable outputValue = new LongWritable(1);
@Override
protected void map(
LongWritable key,
Text value,
Context context)
throws IOException, InterruptedException {
context.write(value, outputValue);
}
}
public static class CJBayesClassfierReducer3 extends Reducer<Text, LongWritable, Text, Text>
{
private List<String> classfications;
@Override
protected void setup(
Reducer<Text, LongWritable, Text, Text>.Context context)
throws IOException, InterruptedException {
classfications = buildClassfications();
for(String classfication : classfications)
{
System.out.println("分类:" + classfication);
}
buildCJGLTable();
CJGLTable.show();
}
private List<String> buildClassfications() throws IOException {
List<String> list = new ArrayList<String>();
FileReader fr = new FileReader("CLASS");
BufferedReader br = new BufferedReader(fr);
String line = null;
while((line = br.readLine()) != null)
{
String tokens[] = line.split(",");
String classfier = tokens[1];
list.add(classfier);
}
br.close();
return list;
}
private void buildCJGLTable() throws IOException {
FileReader fr = new FileReader("GL");
BufferedReader br = new BufferedReader(fr);
String line = null;
while((line = br.readLine()) != null)
{
String tokens[] = line.split(",");
PairOfStrings key = new PairOfStrings(tokens[0],tokens[1]);
CJGLTable.add(key, Double.valueOf(tokens[2]));
}
br.close();
}
@Override
protected void reduce(
Text key,
Iterable<LongWritable> values,
Context context)
throws IOException, InterruptedException {
System.out.println("key=" + key);
System.out.println("values:");
for(LongWritable lw : values)
{
System.out.println(lw);
}
String [] attributes = key.toString().split(",");
String selectedClass = null;
double maxPosterior = 0.0;
for(String aClass : classfications)
{
System.out.println("对于类别:" + aClass);
double posterior = CJGLTable.getClassGL(aClass);
System.out.println("其概率为:" + posterior);
for(String attr : attributes)
{
System.out.println("\t对于条件:" + attr);
double conGL = CJGLTable.getConditionalGL(attr, aClass);
System.out.println("\t其概率为:" + conGL);
posterior *= CJGLTable.getConditionalGL(attr, aClass);
}
if(selectedClass == null)
{
selectedClass = aClass;
maxPosterior = posterior;
}
else
{
if(posterior > maxPosterior)
{
selectedClass = aClass;
maxPosterior = posterior;
}
}
context.write(key, new Text("贝叶斯分类:" + selectedClass + ",其概率为" + maxPosterior));
}
context.write(key, new Text("最终结果:贝叶斯分类为" + selectedClass + ",其概率为" + maxPosterior));
}
}
public static void main(String[] args) throws Exception
{
args = new String[2];
args[0] = "/media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt";
args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes3";
int jobStatus = submitJob(args);
System.exit(jobStatus);
}
public static int submitJob(String[] args) throws Exception {
int jobStatus = ToolRunner.run(new CJBayesClassfier_Step3(), args);
return jobStatus;
}
@SuppressWarnings("deprecation")
@Override
public int run(String[] args) throws Exception {
Configuration conf = getConf();
Job job = new Job(conf);
job.setJobName("Bayes");
job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000" + "#CLASS"));
job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes2/part-r-00000" + "#GL"));
job.setInputFormatClass(TextInputFormat.class);
job.setOutputFormatClass(TextOutputFormat.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(LongWritable.class);
job.setMapperClass(CJBayesClassfierMapper3.class);
job.setReducerClass(CJBayesClassfierReducer3.class);
FileInputFormat.setInputPaths(job, new Path(args[0]));
FileOutputFormat.setOutputPath(job, new Path(args[1]));
FileSystem fs = FileSystem.get(conf);
Path outPath = new Path(args[1]);
if(fs.exists(outPath))
{
fs.delete(outPath, true);
}
boolean status = job.waitForCompletion(true);
return status ? 0 : 1;
}
}
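The reducer scores each class by multiplying its cached prior with the cached conditional probability of every attribute in the record, keeping the class with the largest product. For the sample record Overcast,Hot,High,Strong:

posterior(Yes) = P(Yes) × P(Overcast|Yes) × P(Hot|Yes) × P(High|Yes) × P(Strong|Yes)
             = 0.6428571428571429 × 0.4444444444444444 × 0.2222222222222222 × 0.3333333333333333 × 0.3333333333333333
             ≈ 0.007054673721340388
posterior(No) = 0.0, because (Overcast,No) never occurs in the training data, so P(Overcast|No) = 0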
package cjbayesclassfier;
import java.util.HashMap;
import java.util.Map;
import edu.umd.cloud9.io.pair.PairOfStrings;
/***
* Stores the probability table
* @author chenjie
*/
public class CJGLTable {
private static Map<PairOfStrings,Double> map = new HashMap<PairOfStrings,Double>();
public static void add(PairOfStrings key,Double gl)
{
map.put(key, gl);
}
public static double getClassGL(String aClass)
{
PairOfStrings pos = new PairOfStrings("CLASS",aClass);
Double gl = map.get(pos);
return gl == null ? 0 : gl;
}
public static double getConditionalGL(String conditional,String aClass)
{
PairOfStrings pos = new PairOfStrings(conditional,aClass);
Double gl = map.get(pos);
return gl == null ? 0 : gl;
}
public static void show()
{
for(Map.Entry<PairOfStrings,Double> entry : map.entrySet())
{
System.out.println(entry);
}
}
}
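A note on CJGLTable: it keys the table on Cloud9's PairOfStrings, which relies on that class supplying value-based equals() and hashCode() for the HashMap lookups (Cloud9's pair types do). The static map is filled once in the reducer's setup() and afterwards only read, so sharing it statically is safe here because each reducer task fills its own copy.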
Step 1:

Input: weather.txt
--------------------------
Sunny,Hot,High,Weak,No
Sunny,Hot,High,Strong,No
Overcast,Hot,High,Weak,Yes
Rain,Mild,High,Weak,Yes
Rain,Cool,Normal,Weak,Yes
Rain,Cool,Normal,Strong,No
Overcast,Cool,Normal,Strong,Yes
Sunny,Mild,High,Weak,No
Sunny,Cool,Normal,Weak,Yes
Rain,Mild,Normal,Weak,Yes
Sunny,Mild,Normal,Strong,Yes
Overcast,Mild,High,Strong,Yes
Overcast,Hot,Normal,Weak,Yes
Rain,Mild,High,Strong,No

Output: CLASS-r-00000
----------------------
CLASS,No,5
CLASS,Yes,9

Output: OTHERS-r-00000
--------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6

Step 2:

Cache: CLASS-r-00000
-----------------------
CLASS,No,5
CLASS,Yes,9

Input: OTHERS-r-00000
------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6

Output: part-r-00000
----------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666

Step 3:

Cache: CLASS-r-00000
-------------------------------
CLASS,No,5
CLASS,Yes,9

Cache: part-r-00000
------------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666

Input: weather_predict.txt
---------------------------------
Overcast,Hot,High,Strong

Process:
---------------------------------------------
Class: No
Class: Yes
(High, No)=0.8
(Strong, No)=0.6
(Normal, No)=0.2
(Normal, Yes)=0.6666666666666666
(Strong, Yes)=0.3333333333333333
(CLASS, No)=0.35714285714285715
(CLASS, Yes)=0.6428571428571429
(Cool, No)=0.2
(High, Yes)=0.3333333333333333
(Hot, No)=0.4
(Sunny, No)=0.6
(Weak, No)=0.4
(Cool, Yes)=0.3333333333333333
(Mild, No)=0.4
(Overcast, Yes)=0.4444444444444444
(Rain, No)=0.4
(Rain, Yes)=0.3333333333333333
(Weak, Yes)=0.6666666666666666
(Hot, Yes)=0.2222222222222222
(Sunny, Yes)=0.2222222222222222
(Mild, Yes)=0.4444444444444444
key=Overcast,Hot,High,Strong
values:
1
For class: No
Probability: 0.35714285714285715
	For condition: Overcast
	Probability: 0.0
	For condition: Hot
	Probability: 0.4
	For condition: High
	Probability: 0.8
	For condition: Strong
	Probability: 0.6
For class: Yes
Probability: 0.6428571428571429
	For condition: Overcast
	Probability: 0.4444444444444444
	For condition: Hot
	Probability: 0.2222222222222222
	For condition: High
	Probability: 0.3333333333333333
	For condition: Strong
	Probability: 0.3333333333333333

Output:
Overcast,Hot,High,Strong	Bayes classification: No, probability 0.0
Overcast,Hot,High,Strong	Bayes classification: Yes, probability 0.007054673721340388
Overcast,Hot,High,Strong	Final result: Bayes classification Yes, probability 0.007054673721340388
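Note how class No is ruled out entirely: (Overcast,No) never occurs in the training data, so P(Overcast|No) = 0 and the whole product collapses to 0.0 regardless of the other attributes. The lambda parameter of MLlib's NaiveBayes in the last listing below exists for exactly this situation. Here is a minimal sketch of how Step 2's estimate could be Laplace-smoothed; the class, the method name, and the distinctValues parameter (the number of distinct values the attribute can take) are hypothetical illustrations, not part of the code above:

public class LaplaceSmoothing {
    // Hypothetical smoothed estimate: (count + lambda) / (classCount + lambda * distinctValues);
    // lambda = 1.0 gives classic add-one smoothing.
    public static double smoothedGL(long count, long classCount, int distinctValues, double lambda) {
        return (count + lambda) / (classCount + lambda * distinctValues);
    }

    public static void main(String[] args) {
        // (Overcast,No): count 0, class No covers 5 records, the outlook attribute has 3 distinct values
        System.out.println(smoothedGL(0, 5, 3, 1.0)); // prints 0.125 instead of 0.0
    }
}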
Using Spark (native API)
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object CJBayes {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather.txt"
    val predictFile = "file:///media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt"
    val inputRDD = sc.textFile(input)
    val trainDataSize = inputRDD.count()
    // Step 1 equivalent: emit ((attribute, class), 1) for every attribute plus a ((CLASS, class), 1) marker per line
    val mapRDD = inputRDD.flatMap { line =>
      val result = ArrayBuffer[((String, String), Int)]()
      val tokens = line.split(",")
      val classfier = tokens(tokens.length - 1)
      for (i <- 0 until tokens.length - 1) {
        result += ((tokens(i), classfier) -> 1)
      }
      result += (("CLASS", classfier) -> 1)
      result
    }
    val reduceRDD = mapRDD.reduceByKey(_ + _)
    val countsMap = reduceRDD.collectAsMap()
    // Step 2 equivalent: turn the counts into prior and conditional probabilities
    val PT = new mutable.HashMap[(String, String), Double]()
    val CLASSFICATIONS = new mutable.ArrayBuffer[String]()
    countsMap.foreach { item =>
      val k = item._1
      val v = item._2
      val condition = k._1
      val classfication = k._2
      if (condition.equals("CLASS")) {
        PT.put(k, v.toDouble / trainDataSize.toDouble)
        CLASSFICATIONS += classfication
      } else {
        countsMap.get(("CLASS", classfication)) match {
          case Some(classCount) => PT.put(k, v.toDouble / classCount)
          case None => PT.put(k, 0.0)
        }
      }
    }
    PT.foreach(println)
    // Step 3 equivalent: score every class for each record to predict and keep the best one
    val predict = sc.textFile(predictFile)
    predict.map { line =>
      val attributes = line.split(",")
      var selectedClass: String = null
      var maxPosterior = 0.0
      for (aClass <- CLASSFICATIONS) {
        println("For class: " + aClass)
        var posterior = PT.getOrElse(("CLASS", aClass), 0.0)
        println("Probability: " + posterior)
        for (attr <- attributes) {
          println("\tFor condition: " + attr)
          val probability = PT.getOrElse((attr, aClass), 0.0)
          println("\tProbability: " + probability)
          posterior *= probability
        }
        // Only compare classes once the full product has been accumulated
        if (selectedClass == null || posterior > maxPosterior) {
          selectedClass = aClass
          maxPosterior = posterior
        }
      }
      line + "," + selectedClass + ":" + maxPosterior
    }.foreach(println)
    sc.stop()
  }
}
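One design note on the listing above: the probability table PT is assembled on the driver with collectAsMap() and then captured by the closure passed to predict.map(). That is fine for a table this small; with a large attribute space the usual Spark idiom would be to wrap the table in sc.broadcast() so each executor receives a single read-only copy instead of a per-task serialization.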
Using Spark (MLlib machine learning library)
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

object CJBayes {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather1.txt"
    val data = sc.textFile(input)
    // Each line: space-separated numeric features, a comma, then the numeric class label
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(1).toDouble, Vectors.dense(parts(0).split(' ').map(_.toDouble)))
    }
    // Use 100% of the data as the training set and 0% as the test set
    val splits = parsedData.randomSplit(Array(1.0, 0.0), seed = 11L)
    val training = splits(0)
    val test = splits(1)
    // Train the model: the first argument is the data, the second the smoothing parameter (default 1.0, adjustable)
    val model = NaiveBayes.train(training, lambda = 1.0)
    // Evaluate the model's accuracy (the test set is empty with the split above, so this prints NaN)
    val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
    println("accuracy --> " + accuracy)
    println("Prediction of (2.0,1.0,1.0,2.0): " + model.predict(Vectors.dense(2.0, 1.0, 1.0, 2.0)))
    sc.stop()
  }
}
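MLlib's NaiveBayes consumes numeric LabeledPoint features, so weather1.txt is presumably a numerically encoded version of weather.txt (judging from the parsing code: space-separated feature values before the comma, the numeric class label after it). The original does not show the encoding, so the query vector (2.0,1.0,1.0,2.0) has to be read against whatever attribute-to-number mapping was used. Note also that with a 100%/0% split the test set is empty and the printed accuracy is NaN; something like randomSplit(Array(0.7, 0.3)) would yield a meaningful number.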