
《集体智慧编程》(Programming Collective Intelligence,PCI)第8章介绍了k最近邻算法(k-NN)的用法和实现。


k-NN是一种分类算法(也可用于数值预测,例如本文中的价格预测),它根据k个最近的邻居来判断某一项属于哪个类别。为了确定哪些项是邻居,算法需要一个距离/相似度评分函数,本例中使用的是欧几里得距离。

PCI还介绍了几种在某些情况下可以进一步提高准确性的方法:一是对邻居的结果取加权平均值;二是借助前几章介绍的优化技术(模拟退火和遗传算法)来确定各特征的最佳缩放权重(与之前的所有章节一样,代码位于我的GitHub仓库中)。


package net.briandupreez.pci.chapter8;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class EuclideanDistanceScore {

    /**
     * Determines the Euclidean distance between two equally sized vectors of points.
     *
     * <p>The distance is {@code 0.0} when both vectors are identical and grows without an upper
     * bound as they move apart; it is NOT normalised into the range 0..1.
     *
     * @param list1 first vector
     * @param list2 second vector
     * @return the Euclidean distance between the two vectors (always {@code >= 0})
     * @throws IllegalArgumentException if the vectors contain a different number of values
     */
    public static double distanceList(final List<Double> list1, final List<Double> list2) {
        if (list1.size() != list2.size()) {
            throw new IllegalArgumentException("Same number of values required: "
                    + list1.size() + " vs " + list2.size());
        }

        double sumOfAllSquares = 0.0;
        for (int i = 0; i < list1.size(); i++) {
            final double delta = list2.get(i) - list1.get(i);
            sumOfAllSquares += delta * delta;
        }
        return Math.sqrt(sumOfAllSquares);
    }
}


package net.briandupreez.pci.chapter8;

import org.javatuples.Pair;

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;

 * Created with IntelliJ IDEA.
 * User: bdupreez
 * Date: 2013/07/05
 * Time: 9:08 PM
public class Optimization {

    public List<Pair<Integer, Integer>> createDomain() {

        final List<Pair<Integer, Integer>> domain = new ArrayList<>(4);
        for (int i = 0; i < 4; i++) {
            final Pair<Integer, Integer> pair = new Pair<>(0, 10);
        return domain;

     * Simulated Annealing
     * @param domain list of tuples with min and max
     * @return (global minimum)
    public Double[] simulatedAnnealing(final List<Pair<Integer, Integer>> domain, final double startingTemp, final double cool, final int step) {

        double temp = startingTemp;
        //create random
        Double[] sol = new Double[domain.size()];
        Random random = new Random();
        for (int r = 0; r < domain.size(); r++) {
            sol[r] = Double.valueOf(random.nextInt(19));

        while (temp > 0.1) {
            //pick a random indices
            int i = random.nextInt(domain.size() - 1);

            //pick a directions + or -
            int direction = random.nextInt(step) % 2 == 0 ? -(random.nextInt(step)) : random.nextInt(1);

            Double[] cloneSolr = sol.clone();
            cloneSolr[i] += direction;
            if (cloneSolr[i] < domain.get(i).getValue0()) {
                cloneSolr[i] = Double.valueOf(domain.get(i).getValue0());
            } else if (cloneSolr[i] > domain.get(i).getValue1()) {
                cloneSolr[i] = Double.valueOf(domain.get(i).getValue1());

            //calc current and new cost
            double currentCost = scheduleCost(sol);
            double newCost = scheduleCost(cloneSolr);
            System.out.println("Current: " + currentCost + " New: " + newCost);

            double probability = Math.pow(Math.E, -(newCost - currentCost) / temp);

            // Is it better, or does it make the probability cutoff?
            if (newCost < currentCost || Math.random() < probability) {
                sol = cloneSolr;
            temp = temp * cool;
        return sol;

    public double scheduleCost(Double[] sol) {
        NumPredict numPredict = new NumPredict();
        final List<Map<String,List<Double>>> rescale = numPredict.rescale(numPredict.createWineSet2(), Arrays.asList(sol));
        return numPredict.crossValidate(rescale,0.1,100);

    public Double[] geneticAlgorithm(final List<Pair<Integer, Integer>> domain, final int populationSize,
                                     final int step, final double elite, final int maxIter, final double mutProb) {

        List<Double[]> pop = createPopulation(domain.size(), populationSize);

        final int topElite = new Double(elite * populationSize).intValue();

        final SortedMap<Double, Double[]> scores = new TreeMap<>();
        for (int i = 0; i < maxIter; i++) {
            for (final Double[] run : pop) {
                scores.put(scheduleCost(run), run);
            pop = determineElite(topElite, scores);

            while (pop.size() < populationSize) {

                final Random random = new Random();
                if (Math.random() < mutProb) {

                    final int ran = random.nextInt(topElite);
                    pop.add(mutate(domain, pop.get(ran), step));

                } else {
                    final int ran1 = random.nextInt(topElite);
                    final int ran2 = random.nextInt(topElite);
                    pop.add(crossover(pop.get(ran1), pop.get(ran2), domain.size()));

        return scores.entrySet().iterator().next().getValue();


     * Grab the elites
     * @param topElite how many
     * @param scores   sorted on score
     * @return best ones
    private List<Double[]> determineElite(int topElite, SortedMap<Double, Double[]> scores) {

        Double toKey = null;
        int index = 0;
        for (final Double key : scores.keySet()) {
            if (index++ == topElite) {
                toKey = key;
        scores = scores.headMap(toKey);
        return new ArrayList<>(scores.values());

     * Create a population
     * @param arraySize the array size
     * @param popSize   the population size
     * @return a random population
    private List<Double[]> createPopulation(final int arraySize, final int popSize) {
        final List<Double[]> returnList = new ArrayList<>();
        for (int i = 0; i < popSize; i++) {
            Double[] sol = new Double[arraySize];
            Random random = new Random();
            for (int r = 0; r < arraySize; r++) {
                sol[r] = Double.valueOf(random.nextInt(8));
        return returnList;

     * Mutate a value.
     * @param domain the domain
     * @param vec    the data to be mutated
     * @param step   the step
     * @return mutated array
    private Double[] mutate(final List<Pair<Integer, Integer>> domain, final Double[] vec, final int step) {
        final Random random = new Random();
        int i = random.nextInt(domain.size() - 1);
        Double[] retArr = vec.clone();
        if (Math.random() < 0.5 && (vec[1] - step) > domain.get(i).getValue0()) {
            retArr[i] -= step;
        } else if (vec[i] + step < domain.get(i).getValue1()) {
            retArr[i] += step;
        return vec;

     * Cross over parts of each array
     * @param arr1 array 1
     * @param arr2 array 2
     * @param max  max value
     * @return new array
    private Double[] crossover(final Double[] arr1, final Double[] arr2, final int max) {
        final Random random = new Random();
        int i = random.nextInt(max);
        return concatArrays(Arrays.copyOfRange(arr1, 0, i), Arrays.copyOfRange(arr2, i, arr2.length));

     * Concat 2 arrays
     * @param first  first
     * @param second second
     * @return new combined array
    private Double[] concatArrays(final Double[] first, final Double[] second) {
        Double[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;


package net.briandupreez.pci.chapter8;

import org.javatuples.Pair;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

 * NumPredict.
 * User: bdupreez
 * Date: 2013/08/12
 * Time: 8:29 AM
public class NumPredict {

     * Determine the wine price.
     * @param rating rating
     * @param age    age
     * @return the price
    public double winePrice(final double rating, final double age) {
        final double peakAge = rating - 50;

        //Calculate the price based on rating
        double price = rating / 2;

        if (age > peakAge) {
            //goes bad in 10 years
            price *= 5 - (age - peakAge) / 2;
        } else {
            //increases as it reaches its peak
            price *= 5 * ((age + 1)) / peakAge;

        if (price < 0) {
            price = 0.0;

        return price;

     * Data Generator
     * @return data
    public List<Map<String, List<Double>>> createWineSet1() {

        final List<Map<String, List<Double>>> wineList = new ArrayList<>();
        for (int i = 0; i < 300; i++) {

            double rating = Math.random() * 50 + 50;
            double age = Math.random() * 50;

            double price = winePrice(rating, age);
            price *= (Math.random() * 0.2 + 0.9);

            final Map<String, List<Double>> map = new HashMap<>();
            final List<Double> input = new LinkedList<>();
            map.put("input", input);
            final List<Double> result = new ArrayList();
            map.put("result", result);


        return wineList;

     * Data Generator
     * @return data
    public List<Map<String, List<Double>>> createWineSet2() {

        final List<Map<String, List<Double>>> wineList = new ArrayList<>();
        for (int i = 0; i < 300; i++) {

            double rating = Math.random() * 50 + 50;
            double age = Math.random() * 50;
            final Random random = new Random();
            double aisle = (double) random.nextInt(20);
            double[] sizes = new double[]{375.0, 750.0, 1500.0};
            double bottleSize = sizes[random.nextInt(3)];

            double price = winePrice(rating, age);
            price *= (bottleSize / 750);
            price *= (Math.random() * 0.2 + 0.9);

            final Map<String, List<Double>> map = new HashMap<>();
            final List<Double> input = new LinkedList<>();
            map.put("input", input);
            final List<Double> result = new ArrayList();
            map.put("result", result);


        return wineList;

     * Rescale
     * @param data  data
     * @param scale the scales
     * @return scaled data
    public List<Map<String, List<Double>>> rescale(final List<Map<String, List<Double>>> data, final List<Double> scale) {
        final List<Map<String, List<Double>>> scaledData = new ArrayList<>();
        for (final Map<String, List<Double>> dataItem : data) {
            final List<Double> scaledList = new LinkedList<>();
            for (int i = 0; i < scale.size(); i++) {
                scaledList.add(scale.get(i) * dataItem.get("input").get(i));
            dataItem.put("input", scaledList);

        return scaledData;

     * Determine all the distances from a list
     * @param data all the data
     * @param vec1 one list
     * @return all the distances
    public List<Pair<Double, Integer>> determineDistances(final List<Map<String, List<Double>>> data, final List<Double> vec1) {
        final List<Pair<Double, Integer>> distances = new ArrayList<>();
        int i = 1;
        for (final Map<String, List<Double>> map : data) {
            final List<Double> vec2 = map.get("input");

            distances.add(new Pair(EuclideanDistanceScore.distanceList(vec1, vec2), i++));


        return distances;

     * Use kNN to estimate a new price
     * @param data all the data
     * @param vec1 new fields to price
     * @param k    the amount of neighbours
     * @return the estimated price
    public double knnEstimate(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double avg = 0.0;

        for (int i = 0; i <= k; i++) {
            int idx = distances.get(i).getValue1();
            avg += data.get(idx - 1).get("result").get(0);
        avg = avg / k;
        return avg;

     * KNN using a weighted average of the neighbours
     * @param data the dataset
     * @param vec1 the data to price
     * @param k    number of neighbours
     * @return the weighted price
    public double weightedKnn(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double avg = 0.0;
        double totalWeight = 0.0;

        for (int i = 0; i <= k; i++) {
            double dist = distances.get(i).getValue0();
            int idx = distances.get(i).getValue1();
            double weight = guassianWeight(dist, 5.0);

            avg += weight * data.get(idx - 1).get("result").get(0);
            totalWeight += weight;
        if (totalWeight == 0.0) {
            return 0.0;
        avg = avg / totalWeight;

        return avg;

     * Gaussian Weight function, smoother weight curve that doesnt go to 0
     * @param distance the distance
     * @param sigma    sigma
     * @return weighted value
    public double guassianWeight(final double distance, final double sigma) {
        double alteredDistance = -(Math.pow(distance, 2));
        double sigmaSize = (2 * Math.pow(sigma, 2));
        return Math.pow(Math.E, (alteredDistance / sigmaSize));

     * Split the data for cross validation.
     * @param data        the data to split
     * @param testPercent % of data to use for the tests
     * @return a tuple 0 - training, 1 - test
    public Pair<List, List> divideData(final List<Map<String, List<Double>>> data, final double testPercent) {
        final List trainingList = new ArrayList();
        final List testList = new ArrayList();

        for (final Map<String, List<Double>> dataItem : data) {
            if (Math.random() < testPercent) {
            } else {

        return new Pair(trainingList, testList);

     * Test result and squares the differences to make it more obvious
     * @param trainingSet the training set
     * @param testSet     the test set
     * @return the error
    public double testAlgorithm(final List trainingSet, final List testSet) {
        double error = 0.0;
        final List<Map<String, List<Double>>> typedSet = (List<Map<String, List<Double>>>) testSet;
        for (final Map<String, List<Double>> testData : typedSet) {
            double guess = weightedKnn(trainingSet, testData.get("input"), 3);
            error += Math.pow((testData.get("result").get(0) - guess), 2);
        return error / testSet.size();

     * This runs iterations of the test, and returns an averaged score
     * @param data        the data
     * @param testPercent % test
     * @param trials      number of iterations
     * @return result
    public double crossValidate(final List<Map<String, List<Double>>> data, final double testPercent, final int trials) {

        double error = 0.0;
        for (int i = 0; i < trials; i++) {
            final Pair<List, List> trainingPair = divideData(data, testPercent);
            error += testAlgorithm(trainingPair.getValue0(), trainingPair.getValue1());
        return error / trials;

     * Gives the probability that an item is in a price range between 0 and 1
     * Adds up the neighbours weightd and divides it by the total
     * @param data the data
     * @param vec1 the input
     * @param k    the number of neighbours
     * @param low  low amount of range
     * @param high the high amount
     * @return probability between 0 and 1
    public double probabilityGuess(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k,
                                   final double low, final double high) {

        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double neighbourWeights = 0.0;
        double totalWeights = 0.0;

        for (int i = 0; i < k; i++) {
            double dist = distances.get(i).getValue0();
            int index = distances.get(i).getValue1();
            double weight = guassianWeight(dist, 5);
            final List<Double> result = data.get(index).get("result");
            double v = result.get(0);

            //check if the point is in the range.
            if (v >= low && v <= high) {
                neighbourWeights += weight;
            totalWeights += weight;
        if (totalWeights == 0) {
            return 0;
        return neighbourWeights / totalWeights;



翻译自: https://www.javacodegeeks.com/2013/08/creating-a-price-model-using-k-nearest-neighbours-genetic-algorithm.html





