Multithreaded K-Nearest Neighbors (Part 1)


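I have previously written about the k-nearest neighbors (k-NN) algorithm on MapReduce and Spark. Lately I have been studying multithreading, so I am also writing up k-NN with multiple threads. I won't repeat the details of how k-NN works; see my earlier post:
KNN Algorithm with MapReduce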

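Here we simply use the Bank Marketing dataset from the UCI Machine Learning Repository (I had almost forgotten it existed), and again use Euclidean distance as the metric. In this post, the attributes age, education, balance, housing, duration, and campaign from the Bank Marketing dataset serve as features, and the marital-status attribute (the "marital" column, called MARRIED in the code) is the label we try to predict for a series of test samples.
Because some of these attributes are strings, there is no need for word vectors here; mapping each string value to a floating-point number is enough.
First, the main implementation class, SingleKnn. It stores the training data set and the value k (the number of nearest neighbors whose labels are counted when classifying an instance).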

package com.Knnclassifier;

import java.util.*;

public class SingleKnn {
    private final List<? extends Sample> dataSet; // training data set
    private final int k; // number of nearest neighbors that vote on the label

    public SingleKnn(List<? extends Sample> dataSet, int k) {
        this.dataSet = dataSet;
        this.k = k;
    }

    public String classify(Sample sample) {
        Distance[] distances = new Distance[dataSet.size()];
        int index = 0;

        // Compute the distance from the query sample to every training sample
        for (Sample localExample : dataSet) {
            distances[index] = new Distance();
            distances[index].setIndex(index);
            distances[index].setDistance(EuclideanDistanceCalculator
                    .calculate(localExample, sample));
            index++;
        }

        Arrays.sort(distances); // sort by Euclidean distance, ascending

        Map<String, Integer> results = new HashMap<>();
        int limit = Math.min(k, distances.length); // guard against k > training set size
        for (int i = 0; i < limit; i++) { // take the k closest samples
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, Integer::sum); // count the labels of the k neighbors
        }
        return Collections.max(results.entrySet(),
                Map.Entry.comparingByValue()).getKey(); // label with the highest count
    }
}
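SingleKnn sorts all n distances (O(n log n)) only to read the first k. A common alternative, shown here as a hypothetical sketch rather than the author's code, keeps just the k best candidates in a bounded max-heap (java.util.PriorityQueue), which costs O(n log k), and then does the same majority vote with Map.merge and Collections.max. The class name TopKVote and the distances and labels in main are made up for illustration; the code works on a plain double[] of distances so it runs without the other classes.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical sketch (not from the original post): top-k selection with a bounded
// max-heap instead of sorting every distance, followed by a majority vote.
public class TopKVote {

    // Returns the indices of the k smallest values in distances (in no particular order).
    static List<Integer> kNearest(double[] distances, int k) {
        // Max-heap on distance: the head is the farthest of the current candidates.
        PriorityQueue<Integer> heap = new PriorityQueue<>(
                Comparator.comparingDouble((Integer i) -> distances[i]).reversed());
        for (int i = 0; i < distances.length; i++) {
            heap.offer(i);
            if (heap.size() > k) {
                heap.poll(); // drop the farthest, keeping at most k candidates
            }
        }
        return new ArrayList<>(heap);
    }

    // Majority vote over the labels of the selected neighbors.
    static String vote(List<Integer> neighbors, String[] labels) {
        Map<String, Integer> counts = new HashMap<>();
        for (int i : neighbors) {
            counts.merge(labels[i], 1, Integer::sum);
        }
        return Collections.max(counts.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public static void main(String[] args) {
        double[] distances = {4.0, 0.5, 3.0, 1.0, 0.7};
        String[] labels = {"single", "married", "single", "married", "single"};
        List<Integer> nearest = kNearest(distances, 3);
        System.out.println(vote(nearest, labels)); // prints married
    }
}
```

With k = 3 the three closest points are indices 1, 4 and 3, two of which are labeled "married", so the vote returns "married". As in SingleKnn, when two labels tie (easy to hit with an even k such as 2), Collections.max does not specify which of the tied entries it returns.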

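As the code above shows, we need three more classes: Sample, Distance, and the helper class EuclideanDistanceCalculator, which computes the Euclidean distance.
First, the Euclidean distance helper: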

package com.Knnclassifier;

public class EuclideanDistanceCalculator {
    public static double calculate (Sample example1, Sample example2) {
        double ret = 0.0d;
        double[] data1 = example1.getExample();
        double[] data2 = example2.getExample();

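        if (data1.length != data2.length) {
            throw new IllegalArgumentException("Vectors don't have the same length");
        }

        // Euclidean distance: square root of the sum of squared differences
        for (int i = 0; i < data1.length; i++) {
            double diff = data1[i] - data2[i];
            ret += diff * diff; // accumulate the squared differences
        }
        return Math.sqrt(ret);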
    }
}
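A quick way to check the distance code is the 3-4-5 right triangle: the distance from (0, 0) to (3, 4) must be 5.0. The hypothetical class below (EuclideanCheck, not part of the post) repeats the computation on raw double[] vectors so it runs on its own. The loop has to accumulate with +=; plain assignment would keep only the last squared difference and return 4.0 for this input.

```java
// Hypothetical standalone check (not from the original post): the same Euclidean
// distance computation as EuclideanDistanceCalculator, on raw double[] vectors.
public class EuclideanCheck {

    // Square root of the SUM of squared coordinate differences.
    static double distance(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors don't have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff; // accumulate every term, not just the last one
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        double[] origin = {0.0, 0.0};
        double[] point = {3.0, 4.0};
        System.out.println(distance(origin, point)); // prints 5.0
    }
}
```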

Next, the Sample class. It is also worth writing subclasses of Sample, so that more functionality can be added later.
The Sample class is as follows:

package com.Knnclassifier;

public class Sample {
    private double[] example; // feature vector
    private String tag;       // class label

    public double[] getExample() {
        return example;
    }

    public void setExample(double[] example) {
        this.example = example;
    }

    public String getTag() {
        return tag;
    }

    public void setTag(String tag) {
        this.tag = tag;
    }
}

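The subclass BankMarketing
After all, the point is only to learn multithreading, so there is no need to handle every attribute; picking a few important ones is enough.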

package com.Knnclassifier;

import java.util.HashMap;
import java.util.Map;

public class BankMarketing extends Sample {
    // Column positions after splitting a line on ';'
    public static final int AGE = 0;
    public static final int MARRIED = 2; // the "marital" column, used as the label
    public static final int EDUCATION = 3;
    public static final int BALANCE = 5;
    public static final int HOUSING = 6;
    public static final int DURATION = 11;
    public static final int CAMPAIGN = 12;

    // Numeric codes for the quoted categorical values (education and housing)
    public static final Map<String, Double> map = new HashMap<>();
    static {
        map.put("\"tertiary\"", 0d);
        map.put("\"primary\"", 1d);
        map.put("\"secondary\"", 2d);
        map.put("\"unknown\"", 3d);
        map.put("\"no\"", 0d);
        map.put("\"yes\"", 1d);
    }

    private String age;
    private String education;
    private String balance;
    private String house;
    private String duration;
    private String campaign;

    public String getCampaign() {
        return campaign;
    }

    public void setCampaign(String campaign) {
        this.campaign = campaign;
    }

    public String getAge() {
        return age;
    }

    public void setAge(String age) {
        this.age = age;
    }

    public String getEducation() {
        return education;
    }

    public void setEducation(String education) {
        this.education = education;
    }

    public String getBalance() {
        return balance;
    }

    public void setBalance(String balance) {
        this.balance = balance;
    }

    public String getHouse() {
        return house;
    }

    public void setHouse(String house) {
        this.house = house;
    }

    public String getDuration() {
        return duration;
    }

    public void setDuration(String duration) {
        this.duration = duration;
    }

    // Parses one CSV line into a double[] feature vector and sets the label
    public BankMarketing(String input) {
        if (input == null || input.isEmpty()) {
            return;
        }
        String[] args = input.split(";");
        this.setAge(args[AGE]);
        this.setEducation(args[EDUCATION]);
        this.setBalance(args[BALANCE]);
        this.setHouse(args[HOUSING]);
        this.setDuration(args[DURATION]);
        this.setCampaign(args[CAMPAIGN]);
        double[] example = new double[6];
        example[0] = Double.parseDouble(this.getAge());
        example[1] = map.get(this.getEducation());
        example[2] = Double.parseDouble(this.getBalance());
        example[3] = map.get(this.getHouse());
        example[4] = Double.parseDouble(this.getDuration());
        example[5] = Double.parseDouble(this.getCampaign());
        this.setExample(example);
        this.setTag(args[MARRIED]); // label: marital status, still wrapped in quotes
    }

    @Override
    public String toString() {
        return "BankMarketing{" +
                "age='" + age + '\'' +
                ", education='" + education + '\'' +
                ", balance='" + balance + '\'' +
                ", house='" + house + '\'' +
                ", duration='" + duration + '\'' +
                ", campaign='" + campaign + '\'' +
                ", tag='" + this.getTag() + '\'' +
                '}';
    }
}
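To see what the BankMarketing constructor does with one line of the file, here is a standalone sketch (RecordEncoding is a hypothetical name) using the same column positions and category codes. The record in main is illustrative: it follows the dataset's semicolon-separated format, where string fields are wrapped in double quotes (which is why the map keys include the quotes), and its values were chosen so that it encodes to {30.0, 1.0, 1787.0, 0.0, 79.0, 1.0}, the vector hard-coded in Main. Unlike map.get in BankMarketing, the lookup here throws an IllegalArgumentException naming the unmapped value instead of a NullPointerException from unboxing null.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical standalone sketch (not from the original post): mirrors the parsing done
// by the BankMarketing constructor without depending on Sample.
public class RecordEncoding {
    // Same column positions as BankMarketing
    static final int AGE = 0, MARITAL = 2, EDUCATION = 3, BALANCE = 5,
            HOUSING = 6, DURATION = 11, CAMPAIGN = 12;
    // Same category codes as BankMarketing; keys keep the CSV's double quotes
    static final Map<String, Double> CODES = new HashMap<>();
    static {
        CODES.put("\"tertiary\"", 0d);
        CODES.put("\"primary\"", 1d);
        CODES.put("\"secondary\"", 2d);
        CODES.put("\"unknown\"", 3d);
        CODES.put("\"no\"", 0d);
        CODES.put("\"yes\"", 1d);
    }

    // Looks up a quoted category and fails with a clear message if it has no code.
    static double code(String quoted) {
        Double value = CODES.get(quoted);
        if (value == null) {
            throw new IllegalArgumentException("Unmapped category: " + quoted);
        }
        return value;
    }

    // Turns one semicolon-separated line into the 6-element feature vector.
    static double[] encode(String line) {
        String[] f = line.split(";");
        return new double[] {
            Double.parseDouble(f[AGE]),
            code(f[EDUCATION]),
            Double.parseDouble(f[BALANCE]),
            code(f[HOUSING]),
            Double.parseDouble(f[DURATION]),
            Double.parseDouble(f[CAMPAIGN])
        };
    }

    public static void main(String[] args) {
        // Illustrative record in the dataset's format (not read from the real file)
        String line = "30;\"unemployed\";\"married\";\"primary\";\"no\";1787;\"no\";\"no\";"
                + "\"cellular\";19;\"oct\";79;1;-1;0;\"unknown\";\"no\"";
        System.out.println(Arrays.toString(encode(line))); // [30.0, 1.0, 1787.0, 0.0, 79.0, 1.0]
        System.out.println(line.split(";")[MARITAL]);       // "married" (quotes included)
    }
}
```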

Finally, we still need the Distance class, which pairs a training-sample index with its distance:

package com.Knnclassifier;

public class Distance implements Comparable<Distance> {
    private int index;       // position of the training sample in the data set
    private double distance; // distance from the query sample

    public int getIndex() {
        return index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    @Override
    public int compareTo(Distance o) { // ascending order of distance
        return Double.compare(this.distance, o.distance);
    }
}

Now write a test class to check how everything works together:

package com.Knnclassifier;

import java.io.*;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) throws IOException {
        File file = new File("/home/下载/bank-full.csv"); // data file
        List<BankMarketing> bankList = new ArrayList<>();
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            String str;
            boolean head = true;
            while ((str = br.readLine()) != null) {
                if (head) { // skip the header line
                    head = false;
                    continue;
                }
                bankList.add(new BankMarketing(str));
            }
        }
        System.out.println(bankList.get(0).toString());

        SingleKnn singleKnn = new SingleKnn(bankList, 2);
        Sample sample = new Sample();
        // Too lazy to parse a record, so the feature vector is written directly as a double array
        double[] a = {30.0, 1.0, 1787.0, 0.0, 79.0, 1.0};
        sample.setExample(a);
        long start = System.currentTimeMillis();
        sample.setTag(singleKnn.classify(sample));
        long end = System.currentTimeMillis();
        System.out.println(sample.getTag() + " took " + (end - start) + " ms");
    }
}
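The classes above form the single-threaded baseline. Since the goal of this series is a multithreaded k-NN, here is one possible way, sketched under our own assumptions and not the author's implementation, to parallelize the most expensive step, the distance loop: split the training set into contiguous chunks and let a fixed thread pool (ExecutorService) compute one chunk per task. ParallelDistances and its methods are hypothetical names, and the sketch works on double[][] rather than Sample to stay self-contained.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch (not the author's implementation): compute all query-to-training
// distances in parallel, one contiguous chunk of the training set per task.
public class ParallelDistances {

    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    static double[] distances(double[][] train, double[] query, int threads)
            throws InterruptedException, ExecutionException {
        double[] result = new double[train.length];
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            int chunk = (train.length + threads - 1) / threads; // ceiling division
            List<Future<?>> futures = new ArrayList<>();
            for (int start = 0; start < train.length; start += chunk) {
                final int from = start;
                final int to = Math.min(start + chunk, train.length);
                // Each task writes only its own slice of result, so no locking is needed
                futures.add(pool.submit(() -> {
                    for (int i = from; i < to; i++) {
                        result[i] = euclidean(train[i], query);
                    }
                }));
            }
            for (Future<?> f : futures) {
                f.get(); // wait for the task and rethrow any exception it raised
            }
        } finally {
            pool.shutdown();
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        double[][] train = {{0, 0}, {3, 4}, {6, 8}, {1, 1}};
        double[] d = distances(train, new double[] {0, 0}, 2);
        System.out.println(Arrays.toString(d));
    }
}
```

Because every task writes a disjoint slice of result, no locks are needed, and Future.get() both waits for each task and guarantees that the task's writes are visible to the calling thread. The resulting array could then go through the same sort-and-vote step as SingleKnn.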

Run it and check the result:
[Screenshot of the program output]
