《集体智慧编程》(Programming Collective Intelligence,PCI)的第8章介绍了k最近邻(k-NN)算法的用法和实现。
简单地说:
k-NN是一种分类算法,它使用k个邻居来确定某项属于哪个类别。为了确定要使用哪些邻居,算法需要一个距离/相似性得分函数,在此示例中为欧氏距离。
PCI更进一步,以便在某些情况下提高准确性:对邻居使用加权平均值,并在之前的文章《优化技术——模拟退火和遗传算法》的基础上,使用模拟退火或遗传算法来确定最佳权重(与之前的所有章节一样,代码位于我的GitHub仓库)。
因此,相似性得分函数如下(与之前使用的函数略有不同:之前的版本做了取反处理,在两者相同时返回1;这里直接返回距离,相同时为0):
package net.briandupreez.pci.chapter8;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class EuclideanDistanceScore {

    /**
     * Computes the Euclidean distance between two equally sized lists of coordinates.
     *
     * <p>The value returned is the raw distance: it is {@code 0} when the two lists are
     * identical and has no upper bound (it is not normalised into the range 0..1).
     *
     * @param list1 first list of coordinates
     * @param list2 second list of coordinates
     * @return the square root of the sum of the squared element-wise differences
     * @throws IllegalArgumentException if the two lists do not have the same size
     */
    public static double distanceList(final List<Double> list1, final List<Double> list2) {
        final int size = list1.size();
        if (size != list2.size()) {
            throw new IllegalArgumentException("Same number of values required.");
        }
        double sumOfAllSquares = 0;
        for (int i = 0; i < size; i++) {
            final double difference = list2.get(i) - list1.get(i);
            sumOfAllSquares += difference * difference;
        }
        return Math.sqrt(sumOfAllSquares);
    }
}
由于我最初是用int实现的,所以我更新了模拟退火和遗传算法的代码(教训:做任何与机器学习或人工智能相关的事情时,坚持使用double)。
package net.briandupreez.pci.chapter8;
import org.javatuples.Pair;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
/**
 * Simulated annealing and a genetic algorithm, used to search for the per-column scale
 * factors that give the lowest cross-validation error in {@link NumPredict}.
 *
 * User: bdupreez
 * Date: 2013/07/05
 * Time: 9:08 PM
 */
public class Optimization {

    /** Shared random source for all searches run by this instance. */
    private final Random random = new Random();

    /**
     * Builds the search domain: four (min, max) pairs, each spanning 0 to 10.
     *
     * @return a new list holding four {@code (0, 10)} pairs
     */
    public List<Pair<Integer, Integer>> createDomain() {
        final List<Pair<Integer, Integer>> domain = new ArrayList<>(4);
        for (int i = 0; i < 4; i++) {
            domain.add(new Pair<>(0, 10));
        }
        return domain;
    }

    /**
     * Simulated Annealing.
     *
     * <p>Starts from a random solution inside the domain and, while the temperature is above
     * 0.1, nudges one randomly chosen dimension by a random amount in {@code [-step, step]}
     * (clamped to the domain). A candidate is accepted when it is cheaper, or otherwise with
     * probability {@code exp(-(newCost - currentCost) / temp)}. The temperature is multiplied
     * by {@code cool} after every iteration, so {@code cool} must be below 1 for the loop to end.
     *
     * @param domain       list of tuples with min and max (both inclusive) for each dimension
     * @param startingTemp the initial temperature
     * @param cool         the factor the temperature is multiplied by after each iteration
     * @param step         the largest change applied to a dimension in one iteration
     * @return the last accepted solution; an empty array when the domain is empty
     * @throws IllegalArgumentException if {@code step} is negative
     */
    public Double[] simulatedAnnealing(final List<Pair<Integer, Integer>> domain, final double startingTemp,
                                       final double cool, final int step) {
        Double[] sol = randomSolution(domain);
        if (domain.isEmpty()) {
            // Nothing to optimise, and there is no dimension to pick below.
            return sol;
        }
        double temp = startingTemp;
        while (temp > 0.1) {
            // Pick any dimension, the last one included.
            final int i = random.nextInt(domain.size());
            // Pick a direction and size: uniform over [-step, step].
            final int direction = random.nextInt(2 * step + 1) - step;
            final Double[] candidate = sol.clone();
            candidate[i] = clamp(candidate[i] + direction, domain.get(i));

            // Calculate the current and the new cost.
            final double currentCost = scheduleCost(sol);
            final double newCost = scheduleCost(candidate);
            System.out.println("Current: " + currentCost + " New: " + newCost);
            final double probability = Math.exp(-(newCost - currentCost) / temp);
            // Is it better, or does it make the probability cutoff?
            if (newCost < currentCost || Math.random() < probability) {
                sol = candidate;
            }
            temp = temp * cool;
        }
        return sol;
    }

    /**
     * Cost function: rescales a freshly generated wine data set with the given factors and
     * cross-validates it (10% test data, 100 trials).
     *
     * <p>The data set is generated anew on every call, so two calls with the same
     * {@code sol} generally return different values.
     *
     * @param sol the scale factor for each input column
     * @return the averaged cross-validation error
     */
    public double scheduleCost(Double[] sol) {
        final NumPredict numPredict = new NumPredict();
        final List<Map<String, List<Double>>> rescale =
                numPredict.rescale(numPredict.createWineSet2(), Arrays.asList(sol));
        return numPredict.crossValidate(rescale, 0.1, 100);
    }

    /**
     * Genetic algorithm.
     *
     * <p>Each generation scores the population, keeps the best-scoring candidates and refills
     * the population by mutating or crossing over those elite candidates. Scores are kept in
     * one sorted map for the whole run, so the elite are the best-scoring candidates seen so
     * far, and candidates with an identical score replace one another.
     *
     * @param domain         list of tuples with min and max (both inclusive) for each dimension
     * @param populationSize the number of candidates per generation
     * @param step           the amount a mutation changes a value by
     * @param elite          the fraction of the population carried over (at least one candidate is kept)
     * @param maxIter        the number of generations
     * @param mutProb        the probability that a new candidate is a mutation rather than a crossover
     * @return the candidate with the lowest recorded score; when nothing was scored, the first
     *         candidate of the initial population, or an empty array if there is none
     */
    public Double[] geneticAlgorithm(final List<Pair<Integer, Integer>> domain, final int populationSize,
                                     final int step, final double elite, final int maxIter, final double mutProb) {
        List<Double[]> pop = createPopulation(domain, populationSize);
        // Always keep at least one candidate, otherwise there is nothing to breed from.
        final int topElite = Math.max(1, (int) (elite * populationSize));
        final SortedMap<Double, Double[]> scores = new TreeMap<>();
        for (int i = 0; i < maxIter; i++) {
            for (final Double[] run : pop) {
                scores.put(scheduleCost(run), run);
            }
            pop = determineElite(topElite, scores);
            // Breed only from the elite; there may be fewer of them than topElite.
            final int eliteCount = pop.size();
            while (eliteCount > 0 && pop.size() < populationSize) {
                if (Math.random() < mutProb) {
                    final Double[] parent = pop.get(random.nextInt(eliteCount));
                    pop.add(mutate(domain, parent, step));
                } else {
                    final Double[] parent1 = pop.get(random.nextInt(eliteCount));
                    final Double[] parent2 = pop.get(random.nextInt(eliteCount));
                    pop.add(crossover(parent1, parent2, domain.size()));
                }
            }
            System.out.println(scores);
        }
        if (scores.isEmpty()) {
            return pop.isEmpty() ? new Double[0] : pop.get(0);
        }
        return scores.get(scores.firstKey());
    }

    /**
     * Grab the elites.
     *
     * @param topElite how many
     * @param scores   sorted on score, lowest first
     * @return a new list with the {@code topElite} best candidates, or all of them when fewer are available
     */
    private List<Double[]> determineElite(final int topElite, final SortedMap<Double, Double[]> scores) {
        final List<Double[]> best = new ArrayList<>();
        for (final Double[] candidate : scores.values()) {
            if (best.size() >= topElite) {
                break;
            }
            best.add(candidate);
        }
        return best;
    }

    /**
     * Create a population.
     *
     * @param domain  list of tuples with min and max for each dimension
     * @param popSize the population size
     * @return a random population whose values lie inside the domain
     */
    private List<Double[]> createPopulation(final List<Pair<Integer, Integer>> domain, final int popSize) {
        final List<Double[]> population = new ArrayList<>();
        for (int i = 0; i < popSize; i++) {
            population.add(randomSolution(domain));
        }
        return population;
    }

    /**
     * Create one random solution.
     *
     * @param domain list of tuples with min and max for each dimension
     * @return an array with one whole-numbered value per dimension, between min and max inclusive
     */
    private Double[] randomSolution(final List<Pair<Integer, Integer>> domain) {
        final Double[] sol = new Double[domain.size()];
        for (int r = 0; r < sol.length; r++) {
            final int min = domain.get(r).getValue0();
            final int max = domain.get(r).getValue1();
            sol[r] = (double) (min + random.nextInt(max - min + 1));
        }
        return sol;
    }

    /**
     * Limit a value to a (min, max) range.
     *
     * @param value the value to limit
     * @param range the tuple with min and max
     * @return the value, or the nearest bound when it lies outside the range
     */
    private static double clamp(final double value, final Pair<Integer, Integer> range) {
        return Math.max(range.getValue0(), Math.min(range.getValue1(), value));
    }

    /**
     * Mutate a value.
     *
     * <p>One randomly chosen dimension is moved down or up by {@code step}, provided the
     * result stays inside the domain; otherwise the copy is returned unchanged.
     *
     * @param domain the domain
     * @param vec    the data to be mutated (left untouched)
     * @param step   the step
     * @return mutated copy of the array
     */
    private Double[] mutate(final List<Pair<Integer, Integer>> domain, final Double[] vec, final int step) {
        final int i = random.nextInt(domain.size());
        final Double[] retArr = vec.clone();
        if (Math.random() < 0.5 && vec[i] - step >= domain.get(i).getValue0()) {
            retArr[i] -= step;
        } else if (vec[i] + step <= domain.get(i).getValue1()) {
            retArr[i] += step;
        }
        return retArr;
    }

    /**
     * Cross over parts of each array: the values before a random split point come from the
     * first array, the rest from the second.
     *
     * @param arr1 array 1
     * @param arr2 array 2
     * @param max  exclusive upper bound for the split point
     * @return new array
     */
    private Double[] crossover(final Double[] arr1, final Double[] arr2, final int max) {
        final int i = random.nextInt(max);
        return concatArrays(Arrays.copyOfRange(arr1, 0, i), Arrays.copyOfRange(arr2, i, arr2.length));
    }

    /**
     * Concat 2 arrays.
     *
     * @param first  first
     * @param second second
     * @return new combined array
     */
    private Double[] concatArrays(final Double[] first, final Double[] second) {
        final Double[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }
}
最后,将所有这些整合到我的PCI示例的Java实现中:
package net.briandupreez.pci.chapter8;
import org.javatuples.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
 * NumPredict: k-nearest-neighbour price estimation over generated wine data.
 *
 * <p>Every data item is a map with two entries: {@code "input"} (the feature values) and
 * {@code "result"} (a list whose first element is the price).
 *
 * User: bdupreez
 * Date: 2013/08/12
 * Time: 8:29 AM
 */
public class NumPredict {

    private static final String INPUT = "input";
    private static final String RESULT = "result";
    /** Number of items produced by the data generators. */
    private static final int SAMPLE_COUNT = 300;
    /** Sigma used for the gaussian neighbour weights. */
    private static final double SIGMA = 5.0;
    /** Number of neighbours used when testing the algorithm. */
    private static final int TEST_NEIGHBOURS = 3;

    /**
     * Determine the wine price.
     *
     * @param rating rating
     * @param age    age
     * @return the price, never negative
     */
    public double winePrice(final double rating, final double age) {
        final double peakAge = rating - 50;
        //Calculate the price based on rating
        double price = rating / 2;
        if (age > peakAge) {
            //goes bad in 10 years
            price *= 5 - (age - peakAge) / 2;
        } else {
            //increases as it reaches its peak
            price *= 5 * ((age + 1)) / peakAge;
        }
        if (price < 0) {
            price = 0.0;
        }
        return price;
    }

    /**
     * Data Generator: 300 wines with a random rating (50 to 100) and age (0 to 50), priced by
     * {@link #winePrice(double, double)} with up to 10% of noise either way.
     *
     * @return data, each item with inputs (rating, age) and the price as result
     */
    public List<Map<String, List<Double>>> createWineSet1() {
        final List<Map<String, List<Double>>> wineList = new ArrayList<>(SAMPLE_COUNT);
        for (int i = 0; i < SAMPLE_COUNT; i++) {
            final double rating = Math.random() * 50 + 50;
            final double age = Math.random() * 50;
            double price = winePrice(rating, age);
            price *= (Math.random() * 0.2 + 0.9);
            wineList.add(createItem(price, rating, age));
        }
        return wineList;
    }

    /**
     * Data Generator: like {@link #createWineSet1()}, with two extra inputs, a random aisle
     * (0 to 19) and a bottle size (375, 750 or 1500). The price is scaled by the bottle size
     * relative to 750; the aisle does not affect it.
     *
     * @return data, each item with inputs (rating, age, aisle, bottle size) and the price as result
     */
    public List<Map<String, List<Double>>> createWineSet2() {
        final Random random = new Random();
        final double[] sizes = new double[]{375.0, 750.0, 1500.0};
        final List<Map<String, List<Double>>> wineList = new ArrayList<>(SAMPLE_COUNT);
        for (int i = 0; i < SAMPLE_COUNT; i++) {
            final double rating = Math.random() * 50 + 50;
            final double age = Math.random() * 50;
            final double aisle = (double) random.nextInt(20);
            final double bottleSize = sizes[random.nextInt(sizes.length)];
            double price = winePrice(rating, age);
            price *= (bottleSize / 750);
            price *= (Math.random() * 0.2 + 0.9);
            wineList.add(createItem(price, rating, age, aisle, bottleSize));
        }
        return wineList;
    }

    /**
     * Build one data item.
     *
     * @param price  the value stored under "result"
     * @param inputs the values stored under "input", in the given order
     * @return the new item
     */
    private static Map<String, List<Double>> createItem(final double price, final double... inputs) {
        final List<Double> input = new ArrayList<>(inputs.length);
        for (final double value : inputs) {
            input.add(value);
        }
        final List<Double> result = new ArrayList<>(1);
        result.add(price);
        final Map<String, List<Double>> item = new HashMap<>();
        item.put(INPUT, input);
        item.put(RESULT, result);
        return item;
    }

    /**
     * Rescale: multiplies the i-th input of every item by the i-th scale factor.
     *
     * <p>The items passed in are not modified; every returned item is a copy whose "input"
     * entry holds the scaled values. Only the first {@code scale.size()} inputs are kept.
     *
     * @param data  data
     * @param scale the scales, one per input column
     * @return scaled data
     */
    public List<Map<String, List<Double>>> rescale(final List<Map<String, List<Double>>> data, final List<Double> scale) {
        final List<Map<String, List<Double>>> scaledData = new ArrayList<>(data.size());
        for (final Map<String, List<Double>> dataItem : data) {
            final List<Double> input = dataItem.get(INPUT);
            final List<Double> scaledList = new ArrayList<>(scale.size());
            for (int i = 0; i < scale.size(); i++) {
                scaledList.add(scale.get(i) * input.get(i));
            }
            // Copy the item so the caller's data keeps its original inputs.
            final Map<String, List<Double>> scaledItem = new HashMap<>(dataItem);
            scaledItem.put(INPUT, scaledList);
            scaledData.add(scaledItem);
        }
        return scaledData;
    }

    /**
     * Determine all the distances from a list.
     *
     * @param data all the data
     * @param vec1 one list
     * @return (distance, position) pairs sorted by ascending distance; the position is the
     *         1-based position of the item in {@code data}
     */
    public List<Pair<Double, Integer>> determineDistances(final List<Map<String, List<Double>>> data, final List<Double> vec1) {
        final List<Pair<Double, Integer>> distances = new ArrayList<>(data.size());
        int position = 1;
        for (final Map<String, List<Double>> item : data) {
            final double distance = EuclideanDistanceScore.distanceList(vec1, item.get(INPUT));
            distances.add(new Pair<Double, Integer>(distance, position++));
        }
        Collections.sort(distances);
        return distances;
    }

    /**
     * Use kNN to estimate a new price: the plain average of the k nearest results.
     *
     * @param data all the data
     * @param vec1 new fields to price
     * @param k    the amount of neighbours; capped at the number of data items
     * @return the estimated price, or 0 when there are no neighbours to average
     */
    public double knnEstimate(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        final int neighbours = Math.min(k, distances.size());
        if (neighbours <= 0) {
            return 0.0;
        }
        double total = 0.0;
        for (int i = 0; i < neighbours; i++) {
            total += resultAt(data, distances.get(i).getValue1());
        }
        return total / neighbours;
    }

    /**
     * KNN using a weighted average of the neighbours.
     *
     * @param data the dataset
     * @param vec1 the data to price
     * @param k    number of neighbours; capped at the number of data items
     * @return the weighted price, or 0 when the total weight is 0
     */
    public double weightedKnn(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        final int neighbours = Math.min(k, distances.size());
        double weightedTotal = 0.0;
        double totalWeight = 0.0;
        for (int i = 0; i < neighbours; i++) {
            final Pair<Double, Integer> neighbour = distances.get(i);
            final double weight = guassianWeight(neighbour.getValue0(), SIGMA);
            weightedTotal += weight * resultAt(data, neighbour.getValue1());
            totalWeight += weight;
        }
        if (totalWeight == 0.0) {
            return 0.0;
        }
        return weightedTotal / totalWeight;
    }

    /**
     * Look up the result of the item at a position reported by
     * {@link #determineDistances(List, List)}.
     *
     * @param data     all the data
     * @param position the 1-based position of the item
     * @return the first "result" value of that item
     */
    private static double resultAt(final List<Map<String, List<Double>>> data, final int position) {
        return data.get(position - 1).get(RESULT).get(0);
    }

    /**
     * Gaussian Weight function: {@code exp(-distance^2 / (2 * sigma^2))}, which is 1 at
     * distance 0 and falls off smoothly as the distance grows.
     *
     * @param distance the distance
     * @param sigma    sigma
     * @return weighted value
     */
    public double guassianWeight(final double distance, final double sigma) {
        final double alteredDistance = -(distance * distance);
        final double sigmaSize = 2 * sigma * sigma;
        return Math.exp(alteredDistance / sigmaSize);
    }

    /**
     * Split the data for cross validation. Every item independently goes to the test list
     * with probability {@code testPercent}, otherwise to the training list.
     *
     * @param data        the data to split
     * @param testPercent fraction of the data (0 to 1) to use for the tests
     * @return a tuple 0 - training, 1 - test
     */
    @SuppressWarnings("rawtypes")
    public Pair<List, List> divideData(final List<Map<String, List<Double>>> data, final double testPercent) {
        final List<Map<String, List<Double>>> trainingList = new ArrayList<>();
        final List<Map<String, List<Double>>> testList = new ArrayList<>();
        for (final Map<String, List<Double>> dataItem : data) {
            if (Math.random() < testPercent) {
                testList.add(dataItem);
            } else {
                trainingList.add(dataItem);
            }
        }
        return new Pair<List, List>(trainingList, testList);
    }

    /**
     * Test result and squares the differences to make it more obvious. Every test item is
     * priced with {@link #weightedKnn(List, List, int)} using 3 neighbours.
     *
     * @param trainingSet the training set
     * @param testSet     the test set
     * @return the mean squared error; NaN when the test set is empty
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public double testAlgorithm(final List trainingSet, final List testSet) {
        final List<Map<String, List<Double>>> typedTraining = (List<Map<String, List<Double>>>) trainingSet;
        final List<Map<String, List<Double>>> typedTest = (List<Map<String, List<Double>>>) testSet;
        double error = 0.0;
        for (final Map<String, List<Double>> testData : typedTest) {
            final double guess = weightedKnn(typedTraining, testData.get(INPUT), TEST_NEIGHBOURS);
            final double difference = testData.get(RESULT).get(0) - guess;
            error += difference * difference;
        }
        return error / typedTest.size();
    }

    /**
     * This runs iterations of the test, and returns an averaged score.
     *
     * @param data        the data
     * @param testPercent fraction of the data (0 to 1) used as test data in each trial
     * @param trials      number of iterations
     * @return the error averaged over all trials; NaN when {@code trials} is 0
     */
    @SuppressWarnings("rawtypes")
    public double crossValidate(final List<Map<String, List<Double>>> data, final double testPercent, final int trials) {
        double error = 0.0;
        for (int i = 0; i < trials; i++) {
            final Pair<List, List> split = divideData(data, testPercent);
            error += testAlgorithm(split.getValue0(), split.getValue1());
        }
        return error / trials;
    }

    /**
     * Gives the probability that an item is in a price range between 0 and 1.
     * Adds up the weights of the neighbours inside the range and divides it by the total.
     *
     * @param data the data
     * @param vec1 the input
     * @param k    the number of neighbours; capped at the number of data items
     * @param low  low amount of range (inclusive)
     * @param high the high amount (inclusive)
     * @return probability between 0 and 1; 0 when the total weight is 0
     */
    public double probabilityGuess(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k,
                                   final double low, final double high) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        final int neighbours = Math.min(k, distances.size());
        double neighbourWeights = 0.0;
        double totalWeights = 0.0;
        for (int i = 0; i < neighbours; i++) {
            final Pair<Double, Integer> neighbour = distances.get(i);
            final double weight = guassianWeight(neighbour.getValue0(), SIGMA);
            // Positions are 1-based, resultAt converts them.
            final double v = resultAt(data, neighbour.getValue1());
            //check if the point is in the range.
            if (v >= low && v <= high) {
                neighbourWeights += weight;
            }
            totalWeights += weight;
        }
        if (totalWeights == 0) {
            return 0;
        }
        return neighbourWeights / totalWeights;
    }
}
在阅读有关k-NN的更多信息时,我还偶然发现了以下博客文章:
- 第一篇描述了使用k-NN时的一些困难:《k最近邻——危险的简单》
- 接下来是对k-NN的概述: k-NN算法的详细介绍