Multithreaded K-Nearest Neighbors (Part 2): Fine-Grained Concurrent Version
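The previous post, Multithreaded K-Nearest Neighbors (Part 1), walked through how to write the K-nearest neighbors algorithm serially. This post uses an executor to implement a fine-grained concurrent version of the algorithm. Java executors offer many features; each one is introduced here only when it is actually used.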
Picking up from the previous post, two parts of the serial version can be optimized:
- Distance calculation
- Distance sorting
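To do this we implement the KnnClassifierParallelIndividual class. Like the earlier class, it stores the training data set and the parameter k. The difference is that it adds a ThreadPoolExecutor object that runs the parallel tasks, a field holding the number of worker threads in the executor, and a field that says whether the sort should run in parallel.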
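Executor: with an executor you do not create any Thread objects yourself. It reuses its worker threads, which reduces the overhead of thread creation and makes it easier to control how much of the machine's resources are used.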
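The thread reuse described above can be observed directly. The following is a minimal sketch, not part of the classifier: the class name ThreadReuseDemo and the method distinctWorkers are made up for illustration. It submits many small tasks to a fixed pool and counts how many distinct worker threads ran them.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class, not part of the original classifier.
class ThreadReuseDemo {

    // Runs taskCount small tasks on a fixed pool and returns how many
    // distinct worker threads actually executed them.
    static int distinctWorkers(int poolSize, int taskCount) {
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        Set<String> workerNames = ConcurrentHashMap.newKeySet();
        for (int i = 0; i < taskCount; i++) {
            executor.execute(() -> workerNames.add(Thread.currentThread().getName()));
        }
        executor.shutdown(); // stop accepting new tasks; queued tasks still run
        try {
            executor.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt flag
        }
        return workerNames.size();
    }

    public static void main(String[] args) {
        int workers = distinctWorkers(2, 1000);
        System.out.println("1000 tasks ran on " + workers + " worker thread(s)");
    }
}
```

However many tasks are submitted, the number of threads never exceeds the pool size, which is what keeps resource usage under control.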
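We create one task for each distance to be calculated and send it to the executor, while the main thread waits for all of these tasks to finish.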
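The "one task per item, main thread waits" pattern can be sketched in isolation with a CountDownLatch. The example below is only an illustration under assumed names (LatchPerTaskDemo, squareAll); it squares numbers instead of computing distances, but the coordination is the same: each task writes its own array slot and counts down, and the caller blocks in await() until the count reaches zero.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo class, not part of the original classifier.
class LatchPerTaskDemo {

    // Squares every input value in its own task; each task writes only its own slot.
    static double[] squareAll(double[] input) {
        double[] output = new double[input.length];
        CountDownLatch done = new CountDownLatch(input.length);
        ExecutorService executor = Executors.newFixedThreadPool(
                Runtime.getRuntime().availableProcessors());
        try {
            for (int i = 0; i < input.length; i++) {
                final int slot = i;
                executor.execute(() -> {
                    try {
                        output[slot] = input[slot] * input[slot];
                    } finally {
                        done.countDown(); // count down even if the task fails
                    }
                });
            }
            done.await(); // the calling thread blocks until the count reaches zero
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for tasks", e);
        } finally {
            executor.shutdown();
        }
        return output;
    }

    public static void main(String[] args) {
        double[] result = squareAll(new double[] {1.0, 2.0, 3.0});
        System.out.println(Arrays.toString(result)); // prints [1.0, 4.0, 9.0]
    }
}
```

Calling countDown() in a finally block is a defensive choice: if a task threw before counting down, the waiting thread would otherwise block forever.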
The KnnClassifierParallelIndividual class is designed as follows
import com.Knnclassifier.Distance;
import com.Knnclassifier.Sample;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
public class KnnClassifierParallelIndividual {
    private final List<? extends Sample> dataSet;
    private final int k;
    private final ThreadPoolExecutor executor;
    private final int numThreads;
    private final boolean parallelSort;
    public KnnClassifierParallelIndividual(List<? extends Sample> dataSet, int k, int factor, boolean parallelSort) {
        this.dataSet = dataSet; // training data set
        this.k = k;
        // factor is the number of threads to create per available processor
        this.numThreads = factor * (Runtime.getRuntime().availableProcessors());
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(this.numThreads); // thread pool
        this.parallelSort = parallelSort;
    }
    public String classify(Sample sample) throws Exception {
        Distance[] distances = new Distance[dataSet.size()];
        // one count per training sample; each task counts down once when it finishes
        CountDownLatch endController = new CountDownLatch(dataSet.size());
        int index = 0;
        for (Sample localSample : dataSet) {
            IndividualDistanceTask task = new IndividualDistanceTask(distances,
                    index, localSample, sample, endController);
            executor.execute(task);
            index++;
            // System.out.println("approximate number of threads running tasks: " + executor.getActiveCount());
        }
        endController.await(); // block the main thread until every task has counted down
        if (parallelSort) {
            Arrays.parallelSort(distances);
        } else {
            Arrays.sort(distances);
        }
        // shut down the executor; this classifier instance cannot be reused after this call
        executor.shutdown();
        // count the tags of the k nearest samples and return the most frequent one
        Map<String, Integer> results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, Integer::sum);
        }
        return Collections.max(results.entrySet(),
                Map.Entry.comparingByValue()).getKey();
    }
}
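The final step of classify, sorting by distance and then taking a majority vote among the k nearest samples, can be tried on its own. This is a minimal sketch under assumptions: the Neighbor class is a hypothetical stand-in that bundles a tag with a distance (the post's own Distance class stores an index instead), and the data values are invented.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo class, not part of the original classifier.
class VoteDemo {

    // Assumed stand-in for the post's Distance: a tag plus a distance value.
    static final class Neighbor implements Comparable<Neighbor> {
        final String tag;
        final double distance;

        Neighbor(String tag, double distance) {
            this.tag = tag;
            this.distance = distance;
        }

        @Override
        public int compareTo(Neighbor other) {
            return Double.compare(distance, other.distance); // ascending by distance
        }
    }

    // Sorts a copy by distance, then returns the most frequent tag among the k nearest.
    static String vote(Neighbor[] neighbors, int k, boolean parallelSort) {
        Neighbor[] sorted = neighbors.clone();
        if (parallelSort) {
            Arrays.parallelSort(sorted);
        } else {
            Arrays.sort(sorted);
        }
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < Math.min(k, sorted.length); i++) {
            counts.merge(sorted[i].tag, 1, Integer::sum);
        }
        return Collections.max(counts.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public static void main(String[] args) {
        Neighbor[] neighbors = {
            new Neighbor("yes", 0.5), new Neighbor("no", 0.1),
            new Neighbor("no", 0.2), new Neighbor("yes", 3.0), new Neighbor("no", 4.0)
        };
        System.out.println(vote(neighbors, 3, true)); // prints no
    }
}
```

Both Arrays.sort and Arrays.parallelSort need the elements to be Comparable (or a Comparator to be supplied), so the real Distance class has to define an ordering by distance for the classifier to work.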
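The key piece of this class is therefore the IndividualDistanceTask class, which wraps the distance calculation between the input example and one sample of the training data set as a concurrent task, as shown below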
import com.Knnclassifier.Distance;
import com.Knnclassifier.EuclideanDistanceCalculator;
import com.Knnclassifier.Sample;
import java.util.concurrent.CountDownLatch;
public class IndividualDistanceTask implements Runnable {
    private final Distance[] distances;
    private final int index;
    private final Sample localSample;
    private final Sample sample;
    private final CountDownLatch endController;
    public IndividualDistanceTask(Distance[] distances, int index,
                                  Sample localSample, Sample sample,
                                  CountDownLatch endController) {
        this.distances = distances;
        this.index = index;
        this.localSample = localSample;
        this.sample = sample;
        this.endController = endController;
    }
    @Override
    public void run() {
        // each task writes only to its own slot of the shared array
        distances[index] = new Distance();
        distances[index].setIndex(index);
        distances[index].setDistance(EuclideanDistanceCalculator
                .calculate(localSample, sample));
        // System.out.println("count is " + endController.getCount());
        endController.countDown(); // decrement the latch counter by 1
    }
}
For the design of the DataSet, Sample, and Distance classes, see the previous post
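Since those classes are not reproduced here, the following is only a rough sketch of what a Euclidean distance calculation over two feature arrays might look like. It is not the author's EuclideanDistanceCalculator: the class name EuclideanSketch is invented, and it works on plain double[] arrays rather than on Sample objects.

```java
// Hypothetical sketch, not the EuclideanDistanceCalculator from the previous post.
class EuclideanSketch {

    // Returns the Euclidean distance between two feature vectors of equal length.
    static double calculate(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("feature vectors must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff; // accumulate squared differences
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        System.out.println(calculate(new double[] {0.0, 0.0}, new double[] {3.0, 4.0})); // prints 5.0
    }
}
```

Each call is independent of every other call and touches no shared state, which is why it is safe to run one such calculation per task.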
Launcher class
import com.Knnclassifier.Parallel.KnnClassifierParallelIndividual;
import com.Knnclassifier.Sample;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
public class Main {
    public static void main(String[] args) throws IOException {
        File file = new File("/home/hadoop/下载/bank-full.csv");
        List<BankMarketing> bankList = new ArrayList<>();
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            br.readLine(); // skip the CSV header row
            String line;
            while ((line = br.readLine()) != null) {
                bankList.add(new BankMarketing(line));
            }
        }
        System.out.println(bankList.get(0).toString());
        // the example to classify
        Sample sample = new Sample();
        double[] a = {30.0, 1.0, 1787.0, 0.0, 79.0, 1.0};
        sample.setExample(a);
        // k = 2, two threads per processor, parallel sort enabled
        KnnClassifierParallelIndividual knnClassifierParallelIndividual = new KnnClassifierParallelIndividual(bankList, 2, 2, true);
        long start = System.currentTimeMillis();
        try {
            sample.setTag(knnClassifierParallelIndividual.classify(sample));
        } catch (Exception e) {
            e.printStackTrace();
        }
        long end = System.currentTimeMillis();
        System.out.println(sample.getTag() + " elapsed (ms): " + (end - start));
    }
}
The result of running it is as follows