多线程之K-近邻算法
之前曾经写过在MapReduce和Spark下的k近邻算法,最近在看多线程,于是也就顺便写写多线程下的k近邻算法,k近邻算法的详细原理就不写了,可以看我之前的博客
MapReduce之KNN算法
在这里直接使用UCI机器学习资源库中的Bank Marketing数据集即可(之前忘记有这么个数据集了),距离度量依旧采用欧氏距离。本文使用Bank Marketing数据集中的age,education,balance,house,duration,campaign这几个属性作为特征,married作为标签,通过一连串的测试数据来验证对married这个属性的判断
由于一些参数是字符串形式,在这里没必要使用词向量,把它们映射成浮点型的数值就可以
先写主要实现类SingleKnn,该类存储了训练数据集和数值k(即参与投票、用于确定某个实例标签的最近邻样本数量)
import java.util.*;
/**
 * Single-threaded k-nearest-neighbours classifier.
 *
 * <p>Holds a training data set and the neighbour count {@code k}. A sample is
 * classified by computing its distance to every training sample, taking the
 * closest ones and returning the tag that occurs most often among them.
 */
public class SingleKnn {
    /** Training samples; each one supplies a feature vector and a tag. */
    private final List<? extends Sample> dataSet;
    /** Number of nearest neighbours that take part in the vote. */
    private final int k;

    /**
     * @param dataSet training samples to compare against
     * @param k       number of nearest neighbours used for the vote
     */
    public SingleKnn(List<? extends Sample> dataSet, int k) {
        this.dataSet = dataSet;
        this.k = k;
    }

    /**
     * Predicts the tag of {@code sample} by majority vote among its nearest
     * training samples.
     *
     * <p>If {@code k} is larger than the training set, every training sample
     * votes instead of failing with an out-of-bounds access.
     *
     * @param sample the sample to classify
     * @return the most frequent tag among the nearest neighbours
     * @throws NoSuchElementException if no neighbour can vote, i.e. the
     *         training set is empty or {@code k} is not positive
     */
    public String classify(Sample sample) {
        Distance[] distances = new Distance[dataSet.size()];
        int index = 0;
        for (Sample localExample : dataSet) {
            Distance distance = new Distance();
            distance.setIndex(index);
            distance.setDistance(EuclideanDistanceCalculator.calculate(localExample, sample));
            distances[index] = distance;
            index++;
        }
        // Order by ascending distance so the nearest neighbours come first.
        Arrays.sort(distances);

        // Never read past the end of the array when k exceeds the data-set size.
        int neighbours = Math.min(k, distances.length);
        if (neighbours <= 0) {
            throw new NoSuchElementException(
                    "No neighbours to vote: k=" + k + ", training samples=" + distances.length);
        }

        // Count how often each tag occurs among the nearest neighbours.
        Map<String, Integer> votes = new HashMap<>();
        for (int i = 0; i < neighbours; i++) {
            Sample neighbour = dataSet.get(distances[i].getIndex());
            votes.merge(neighbour.getTag(), 1, Integer::sum);
        }
        // The tag with the highest count wins.
        return Collections.max(votes.entrySet(), Map.Entry.comparingByValue()).getKey();
    }
}
由上述的一段代码可以看出,我们需要定义三个类,Sample,Distance和计算欧式距离的辅助类EuclideanDistanceCalculator
先编写计算欧式距离的辅助类
package com.Knnclassifier;
/**
 * Helper that computes the Euclidean distance between the feature vectors of
 * two {@link Sample}s.
 */
public class EuclideanDistanceCalculator {
    /**
     * Returns the Euclidean distance between the two samples' feature vectors,
     * i.e. the square root of the sum of squared per-dimension differences.
     *
     * @param example1 first sample
     * @param example2 second sample
     * @return the Euclidean distance between the two feature vectors
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static double calculate(Sample example1, Sample example2) {
        double[] data1 = example1.getExample();
        double[] data2 = example2.getExample();
        if (data1.length != data2.length) {
            throw new IllegalArgumentException("Vector doesn't have the same length");
        }
        double sumOfSquares = 0.0d;
        for (int i = 0; i < data1.length; i++) {
            double diff = data1[i] - data2[i];
            // Accumulate; plain assignment here would keep only the last dimension.
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }
}
好,现在我们需要编写类Sample,还可以考虑编写一个Sample的扩展类,以便扩展更多的功能
Sample类如下
package com.Knnclassifier;
/**
 * A single data point: a numeric feature vector plus the tag (label) attached
 * to it. Both values are unset ({@code null}) until assigned through the
 * setters.
 */
public class Sample {

    /** Label of this sample. */
    private String tag;

    /** Feature vector of this sample. */
    private double[] example;

    /** @return the tag currently stored, or {@code null} if none was set */
    public String getTag() {
        return this.tag;
    }

    /** @param tag the tag to store */
    public void setTag(String tag) {
        this.tag = tag;
    }

    /**
     * @return the stored feature array itself (not a copy), or {@code null}
     *         if none was set
     */
    public double[] getExample() {
        return this.example;
    }

    /** @param example the feature array to store; the reference is kept as is */
    public void setExample(double[] example) {
        this.example = example;
    }
}
扩展类BankMarketing
毕竟只是学习如何使用多线程,没必要处理全部的属性,挑几个重要的即可
package com.Knnclassifier;
import java.util.HashMap;
import java.util.Map;
/**
 * A {@link Sample} built from one {@code ';'}-separated record.
 *
 * <p>Six columns (age, education, balance, housing, duration, campaign) are
 * converted into the feature vector, and the column at {@link #MARRIED}
 * becomes the tag. The raw text of each used column is also kept and exposed
 * through getters.
 */
public class BankMarketing extends Sample {
    // Positions of the used columns after splitting a record on ';'.
    public static final Integer AGE = 0;
    public static final Integer EDUCATION = 3;
    public static final Integer BALANCE = 5;
    public static final Integer HOUSING = 6;
    /** Position of the duration column. */
    public static final Integer DURATION = 11;
    /**
     * Misspelled alias of {@link #DURATION}, kept so existing references
     * still compile.
     *
     * @deprecated use {@link #DURATION}
     */
    @Deprecated
    public static final Integer DURATIOn = DURATION;
    public static final Integer CAMPAIGN = 12;
    public static final Integer MARRIED = 2;

    /** Highest column index read (CAMPAIGN = 12) plus one. */
    private static final int MIN_FIELDS = 13;

    /**
     * Numeric codes for the categorical columns. Keys include the surrounding
     * double quotes because the raw column text is looked up unchanged.
     */
    public static final Map<String, Double> map = new HashMap<>();

    static {
        // Education levels.
        map.put("\"tertiary\"", 0d);
        map.put("\"primary\"", 1d);
        map.put("\"secondary\"", 2d);
        map.put("\"unknown\"", 3d);
        // Yes/no flag used by the housing column.
        map.put("\"no\"", 0d);
        map.put("\"yes\"", 1d);
    }

    private String age;
    private String education;
    private String balance;
    private String house;
    private String duration;
    private String campaign;

    /**
     * Parses one record into raw column values, a six-element feature vector
     * and a tag.
     *
     * <p>A {@code null} or empty {@code input} leaves every field, the feature
     * vector and the tag unset.
     *
     * @param input one record with columns separated by {@code ';'}
     * @throws IllegalArgumentException if the record has too few columns, a
     *         categorical value has no code in {@link #map}, or a numeric
     *         column cannot be parsed ({@link NumberFormatException})
     */
    public BankMarketing(String input) {
        if (input == null || input.length() == 0) {
            return;
        }
        String[] args = input.split(";");
        if (args.length < MIN_FIELDS) {
            throw new IllegalArgumentException("Expected at least " + MIN_FIELDS
                    + " ';'-separated fields but found " + args.length + ": " + input);
        }
        // Assign fields directly rather than through overridable setters.
        this.age = args[AGE];
        this.education = args[EDUCATION];
        this.balance = args[BALANCE];
        this.house = args[HOUSING];
        this.duration = args[DURATION];
        this.campaign = args[CAMPAIGN];

        double[] example = new double[6];
        example[0] = Double.parseDouble(this.age);
        example[1] = encode("education", this.education);
        example[2] = Double.parseDouble(this.balance);
        example[3] = encode("housing", this.house);
        example[4] = Double.parseDouble(this.duration);
        example[5] = Double.parseDouble(this.campaign);
        this.setExample(example);
        this.setTag(args[MARRIED]);
    }

    /** Looks up the numeric code of a categorical value, failing clearly when it is unknown. */
    private static double encode(String attribute, String rawValue) {
        Double code = map.get(rawValue);
        if (code == null) {
            throw new IllegalArgumentException(
                    "Unknown " + attribute + " value: " + rawValue);
        }
        return code;
    }

    public String getCampaign() {
        return campaign;
    }

    public void setCampaign(String campaign) {
        this.campaign = campaign;
    }

    public String getAge() {
        return age;
    }

    public void setAge(String age) {
        this.age = age;
    }

    public String getEducation() {
        return education;
    }

    public void setEducation(String education) {
        this.education = education;
    }

    public String getBalance() {
        return balance;
    }

    public void setBalance(String balance) {
        this.balance = balance;
    }

    public String getHouse() {
        return house;
    }

    public void setHouse(String house) {
        this.house = house;
    }

    public String getDuration() {
        return duration;
    }

    public void setDuration(String duration) {
        this.duration = duration;
    }

    /** Returns the raw column values and the tag, each wrapped in single quotes. */
    @Override
    public String toString() {
        return "BankMarketing{" +
                "age='" + age + '\'' +
                ", education='" + education + '\'' +
                ", balance='" + balance + '\'' +
                ", house='" + house + '\'' +
                ", duration='" + duration + '\'' +
                ", campaign='" + campaign + '\'' +
                // Plain ASCII quote here, matching the other fields.
                ", tag='" + this.getTag() + '\'' +
                '}';
    }
}
好,现在我们还差一个距离类Distance
package com.Knnclassifier;
/**
 * Pairs a distance value with the index of the sample it was measured
 * against. Instances sort by ascending distance.
 */
public class Distance implements Comparable<Distance> {
    /** Index of the sample this distance refers to. */
    private Integer index;
    /** The measured distance. */
    private double distance;

    public Integer getIndex() {
        return index;
    }

    public void setIndex(Integer index) {
        this.index = index;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    /**
     * Orders by ascending distance.
     *
     * <p>Delegates to {@link Double#compare(double, double)} so the ordering
     * is total: NaN sorts after every other value instead of comparing as
     * "equal" to everything, which would violate the {@link Comparable}
     * contract and can make sorting fail.
     *
     * @param o the distance to compare with
     * @return a negative value, zero or a positive value when this distance is
     *         smaller than, equal to or larger than {@code o}'s
     */
    @Override
    public int compareTo(Distance o) {
        return Double.compare(this.distance, o.getDistance());
    }
}
现在编写测试类,启动测试一下整体效果
package com.Knnclassifier;
import com.Knnclassifier.Parallel.KnnClassifierParallelIndividual;
import java.io.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Demo entry point: loads the data file, builds a {@link SingleKnn} with
 * k = 2 and classifies one hand-written sample, printing the predicted tag
 * and the elapsed time in milliseconds.
 */
public class Main {
    /** Data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "/home/下载/bank-full.csv";

    /**
     * @param args optional; {@code args[0]} overrides the default data file path
     * @throws IOException if the file cannot be read or contains no data rows
     */
    public static void main(String[] args) throws IOException {
        File file = new File(args.length > 0 ? args[0] : DEFAULT_DATA_FILE);
        List<BankMarketing> bankList = new ArrayList<>();
        // try-with-resources closes the reader even if parsing a record throws.
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            br.readLine(); // the first line is the header, not a record
            String line;
            while ((line = br.readLine()) != null) {
                if (line.isEmpty()) {
                    // A blank line would yield a sample without a feature vector.
                    continue;
                }
                bankList.add(new BankMarketing(line));
            }
        }
        if (bankList.isEmpty()) {
            throw new IOException("No data rows found in " + file);
        }
        System.out.println(bankList.get(0).toString());

        SingleKnn singleKnn = new SingleKnn(bankList, 2);
        Sample sample = new Sample();
        // Feature vector written directly as doubles instead of parsing a raw record.
        double[] a = {30.0, 1.0, 1787.0, 0.0, 79.0, 1.0};
        sample.setExample(a);

        long start = System.currentTimeMillis();
        sample.setTag(singleKnn.classify(sample));
        long end = System.currentTimeMillis();
        System.out.println(sample.getTag() + " 耗时 " + (end - start));
    }
}
执行一下,看看结果