说明:算法来自于《集体智慧编程》-第五章
原书代码用 Python 实现,这两天看这章书,改用 Java 实现。
问题描述:Glass 一家六人分散在全国各地,要飞到 LGA 碰头聚会,求总花费最少的航班安排。
与原书代码的不同之处:成本计算中额外加入了旅途时间,按每分钟 0.5 计(代码中旅途时间以分钟为单位)。
/**
*
* FILENAME: Optimization.java
* AUTHOR: vivizhyy[at]gmail.com
* STUID: whu
* DATE: 2010-4-12
* USAGE :
*/
package ch5.optimization;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Random;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.joda.time.DateTime;
import org.joda.time.LocalTime;
/**
 * Group-travel schedule optimisation (after "Programming Collective Intelligence", ch. 5).
 *
 * <p>Six family members fly from their home airports to {@code destination} and back. A solution
 * is an int vector of length {@code MEMBER_NUM * 2}: entry {@code 2*i} selects member i's
 * outbound flight and entry {@code 2*i+1} the return flight, both as indexes into the flight list
 * of the corresponding route.
 */
public class Optimization {

    /** Home airport code of each family member; filled in by {@link #initPeople()}. */
    private HashMap<String, String> people = new HashMap<String, String>();
    /** Family members, in the order their two flights appear in a solution vector. */
    private static String[] family = {"Seymour", "Pranny", "Zooey", "Wait", "Buddy", "Les"};
    /** Airport where the family meets. */
    private String destination = "LGA";
    private Flights flights = new Flights();
    private static final int MEMBER_NUM = 6;
    /**
     * Highest flight index the search operators may produce. Assumes schedule.txt lists 10 flights
     * per route (indexes 0..9) -- TODO confirm against the data file.
     */
    private static final int MAX_FLIGHT_INDEX = 9;
    /** log4j configuration file (machine-specific path kept from the original setup). */
    private static final String LOG4J_CONFIG =
            "D:/Documents/NetBeansProjects/CollectiveProgramming/src/log4j.properties";
    /** Set once log4j is configured, so the cost function does not reconfigure it on every call. */
    private static boolean loggingConfigured = false;
    private Logger log = Logger.getLogger(Optimization.class.getName());
    /**
     * Single random source shared by all operators. Creating a fresh Random seeded from nanoTime on
     * every call is wasteful, and nearby seeds give correlated first values.
     */
    private final Random random = new Random();
    /** Flights per "origin->dest" route; Flights re-reads the schedule file on every lookup. */
    private final HashMap<String, Flights[]> routeCache = new HashMap<String, Flights[]>();
    /** Orders scores cheapest first, with a contract-respecting comparison. */
    private static final Comparator<Score> BY_PRICE = new Comparator<Score>() {
        @Override
        public int compare(Score s1, Score s2) {
            return Double.compare(s1.price, s2.price);
        }
    };

    /** Registers the home airport of every family member. */
    private void initPeople() {
        this.people.put("Seymour", "BOS");
        this.people.put("Pranny", "DAL");
        this.people.put("Zooey", "CAK");
        this.people.put("Wait", "MIA");
        this.people.put("Buddy", "ORD");
        this.people.put("Les", "OMA");
    }

    /**
     * Prints one line per family member: name, home airport, then departure-arrival time and price
     * of the outbound and the return flight selected by {@code times}.
     *
     * @param times solution vector (see class comment)
     */
    public void printSchedule(int[] times) {
        StringBuilder scheduleResult = new StringBuilder();
        for (int index = 0; index < family.length; index++) {
            String mem = family[index];
            String home = people.get(mem);
            Flights out = flightAt(home, destination, times[index * 2]);
            Flights back = flightAt(destination, home, times[index * 2 + 1]);
            scheduleResult.append(mem).append("\t").append(home).append("\t");
            scheduleResult.append(out.getDepart()).append("-").append(out.getArraive())
                    .append("\t$").append(out.getPrice()).append("\t");
            scheduleResult.append(back.getDepart()).append("-").append(back.getArraive())
                    .append("\t$").append(back.getPrice()).append("\n");
        }
        System.out.println(scheduleResult);
    }

    /**
     * Converts a time of day to minutes since midnight.
     *
     * @param t time of day
     * @return {@code hour * 60 + minute}
     */
    public int getMinutes(LocalTime t) {
        return (t.getMinuteOfHour() + t.getHourOfDay() * 60);
    }

    /** Configures log4j once instead of on every cost evaluation. */
    private static void configureLoggingOnce() {
        if (!loggingConfigured) {
            PropertyConfigurator.configure(LOG4J_CONFIG);
            loggingConfigured = true;
        }
    }

    /** Returns the flights of a route, asking {@link Flights} only on the first request. */
    private Flights[] route(String origin, String dest) {
        String key = origin + "->" + dest;
        Flights[] cached = routeCache.get(key);
        if (cached == null) {
            cached = flights.getFlightByOriginAndDest(origin, dest);
            routeCache.put(key, cached);
        }
        return cached;
    }

    /** Returns flight {@code index} of a route, failing with a descriptive message if it is absent. */
    private Flights flightAt(String origin, String dest, int index) {
        Flights[] candidates = route(origin, dest);
        if (index < 0 || index >= candidates.length || candidates[index] == null) {
            throw new IllegalArgumentException(
                    "no flight #" + index + " from " + origin + " to " + dest);
        }
        return candidates[index];
    }

    /**
     * Cost of a schedule. It is the sum of:
     * <ul>
     *   <li>all ticket prices;</li>
     *   <li>$50 if the last arrival is later than the earliest departure (one more car-rental day);</li>
     *   <li>1 per minute everybody waits at the airport, for the last arrival or after the first
     *       departure;</li>
     *   <li>0.5 per minute spent in the air.</li>
     * </ul>
     *
     * @param sol solution vector (see class comment)
     * @return total cost, lower is better
     */
    public double scheduleCost(int[] sol) {
        configureLoggingOnce();
        double totalPrice = 0.0;
        int lastArrival = 0;
        int earliestDep = 24 * 60;
        int totalTravel = 0;
        Flights[] outBound = new Flights[MEMBER_NUM];
        Flights[] returnFlight = new Flights[MEMBER_NUM];
        for (int i = 0; i < MEMBER_NUM; i++) {
            // Outbound and return flight of member i.
            String home = people.get(family[i]);
            outBound[i] = flightAt(home, destination, sol[i * 2]);
            returnFlight[i] = flightAt(destination, home, sol[i * 2 + 1]);
            // Ticket prices.
            totalPrice += outBound[i].getPrice();
            totalPrice += returnFlight[i].getPrice();
            // Time in the air, in minutes.
            totalTravel += getMinutes(outBound[i].getArraive()) - getMinutes(outBound[i].getDepart());
            totalTravel += getMinutes(returnFlight[i].getArraive()) - getMinutes(returnFlight[i].getDepart());
            // Track the latest arrival and the earliest departure.
            lastArrival = Math.max(lastArrival, getMinutes(outBound[i].getArraive()));
            earliestDep = Math.min(earliestDep, getMinutes(returnFlight[i].getDepart()));
        }
        int totalWait = 0;
        for (int i = 0; i < MEMBER_NUM; i++) {
            totalWait += lastArrival - getMinutes(outBound[i].getArraive());
            totalWait += getMinutes(returnFlight[i].getDepart()) - earliestDep;
        }
        // An extra day of car rental is needed if someone arrives after the first person leaves.
        if (lastArrival > earliestDep) {
            totalPrice += 50;
        }
        return totalPrice + totalWait + totalTravel * 0.5;
    }

    /**
     * Random search: evaluates {@code loopTimes} random solutions and keeps the cheapest.
     *
     * @param loopTimes number of random solutions to try
     * @return cheapest solution found (all zeros if {@code loopTimes <= 0})
     */
    public int[] randomOptimize(int loopTimes) {
        int[] bestr = new int[MEMBER_NUM * 2];
        double best = Double.MAX_VALUE;
        for (int i = 0; i < loopTimes; i++) {
            int[] r = randomResult();
            double price = scheduleCost(r);
            if (price < best) {
                best = price;
                bestr = r;
            }
        }
        System.out.println("total cost: " + best);
        return bestr;
    }

    /**
     * Hill climbing. It starts from a random solution and repeatedly moves to the cheapest
     * neighbour (one entry changed by +/-1, within [0, MAX_FLIGHT_INDEX]). It stops when no
     * neighbour is strictly cheaper than the current solution.
     *
     * @return a local optimum
     */
    public int[] hillclimb() {
        int[] sol = randomResult();
        double current = scheduleCost(sol);
        int count = 0;
        while (true) {
            count++;
            int[] bestNeighbor = null;
            double best = current;
            for (int j = 0; j < sol.length; j++) {
                for (int delta = -1; delta <= 1; delta += 2) {
                    int value = sol[j] + delta;
                    if (value < 0 || value > MAX_FLIGHT_INDEX) {
                        continue;
                    }
                    int[] neighbor = sol.clone();
                    neighbor[j] = value;
                    double cost = scheduleCost(neighbor);
                    if (cost < best) {
                        best = cost;
                        bestNeighbor = neighbor;
                    }
                }
            }
            if (bestNeighbor == null) {
                System.out.println("best: " + current);
                System.out.println("loop: " + count);
                return sol;
            }
            sol = bestNeighbor;
            current = best;
        }
    }

    /**
     * Simulated annealing. At each step one random entry is moved by a random amount in
     * {@code [-step, step]} and clamped to [0, MAX_FLIGHT_INDEX]. The move is kept if it is cheaper,
     * or otherwise with probability {@code exp(-(eb - ea) / T)}. T is multiplied by {@code cool}
     * until it drops to 0.1 or below.
     *
     * @param T initial temperature
     * @param cool cooling factor, must be below 1 (otherwise the loop would never end)
     * @param step maximum change applied to one entry per step, must be >= 0
     * @return the final solution
     * @throws IllegalArgumentException if {@code cool >= 1} while {@code T > 0.1}
     */
    public int[] annealingoptimize(double T, double cool, int step) {
        if (T > 0.1 && !(cool < 1.0)) {
            throw new IllegalArgumentException("cool must be < 1 for the temperature to fall, got " + cool);
        }
        int[] vec = randomResult();
        while (T > 0.1) {
            int index = random.nextInt(vec.length);
            int dir = random.nextInt(2 * step + 1) - step;
            int[] vecb = vec.clone();
            vecb[index] = clampIndex(vec[index] + dir);
            double ea = scheduleCost(vec);
            double eb = scheduleCost(vecb);
            if (eb < ea || random.nextDouble() < Math.exp(-(eb - ea) / T)) {
                vec = vecb;
            }
            T *= cool;
        }
        System.out.println(scheduleCost(vec));
        return vec;
    }

    /** Clamps a flight index into [0, MAX_FLIGHT_INDEX]. */
    private static int clampIndex(int value) {
        return Math.max(0, Math.min(MAX_FLIGHT_INDEX, value));
    }

    /**
     * Genetic algorithm. Each generation keeps the cheapest {@code elite} fraction. It refills the
     * population with mutants (probability {@code mutprob}) or crossovers of those survivors.
     *
     * @param popSize population size, must be positive
     * @param step mutation step
     * @param mutprob probability that a new member is a mutation rather than a crossover
     * @param elite fraction of the population kept each generation (at least one member is kept)
     * @param maxiter number of generations
     * @return cheapest member of the last ranked generation
     * @throws IllegalArgumentException if {@code popSize <= 0}
     */
    public int[] geneticoptimize(int popSize, int step, double mutprob, double elite, int maxiter) {
        if (popSize <= 0) {
            throw new IllegalArgumentException("popSize must be positive, got " + popSize);
        }
        // Initial population.
        ArrayList<int[]> pop = new ArrayList<int[]>(popSize);
        for (int i = 0; i < popSize; i++) {
            pop.add(randomResult());
        }
        // Survivors per generation; at least one so parents can always be drawn.
        int topelite = Math.min(popSize, Math.max(1, (int) (elite * popSize)));
        ArrayList<Score> scores = null;
        for (int j = 0; j < maxiter; j++) {
            // Fresh ranking every generation (scores must not accumulate across generations).
            scores = rank(pop);
            ArrayList<int[]> next = new ArrayList<int[]>(popSize);
            for (int n = 0; n < topelite; n++) {
                next.add(scores.get(n).list);
            }
            while (next.size() < popSize) {
                if (random.nextDouble() < mutprob) { // mutation
                    next.add(mutate(scores.get(random.nextInt(topelite)).list, step));
                } else { // crossover
                    int[] c1 = scores.get(random.nextInt(topelite)).list;
                    int[] c2 = scores.get(random.nextInt(topelite)).list;
                    next.add(crossOver(c1, c2));
                }
            }
            pop = next;
            System.out.println(scores.get(0).price);
        }
        if (scores == null) {
            // maxiter <= 0: return the best of the initial population.
            scores = rank(pop);
        }
        return scores.get(0).list;
    }

    /** Scores every member of {@code pop} and returns the scores sorted cheapest first. */
    private ArrayList<Score> rank(ArrayList<int[]> pop) {
        ArrayList<Score> scores = new ArrayList<Score>(pop.size());
        for (int[] member : pop) {
            Score s = new Score();
            s.list = member;
            s.price = scheduleCost(member);
            scores.add(s);
        }
        Collections.sort(scores, BY_PRICE);
        return scores;
    }

    /**
     * Mutation: returns a copy of {@code vec} with one random entry moved down (probability 0.5,
     * only if the entry is above 0) or up by {@code step}. The entry is left unchanged if the move
     * would leave [0, MAX_FLIGHT_INDEX].
     *
     * @param vec solution to mutate (not modified)
     * @param step amount to move the chosen entry
     * @return the mutated copy
     */
    public int[] mutate(int[] vec, int step) {
        int[] mutateR = vec.clone();
        int index = random.nextInt(vec.length);
        if (random.nextDouble() < 0.5 && vec[index] > 0) {
            if (vec[index] - step >= 0) {
                mutateR[index] = vec[index] - step;
            }
        } else if (vec[index] + step <= MAX_FLIGHT_INDEX) {
            mutateR[index] = vec[index] + step;
        }
        return mutateR;
    }

    /**
     * Single-point crossover. The cut point is chosen in {@code [1, length-2]}, so both parents
     * always contribute. The result takes entries {@code [0, cut)} from {@code r1} and the rest
     * from {@code r2}.
     *
     * @param r1 first parent (not modified)
     * @param r2 second parent, same length as {@code r1} (not modified)
     * @return the child
     */
    public int[] crossOver(int[] r1, int[] r2) {
        int[] crossOverR = r2.clone();
        int index = 1 + random.nextInt(r1.length - 2);
        System.arraycopy(r1, 0, crossOverR, 0, index);
        return crossOverR;
    }

    /**
     * Generates a random solution of length {@code MEMBER_NUM * 2}. Each entry is uniformly
     * distributed in [0, MAX_FLIGHT_INDEX], both ends inclusive.
     *
     * @return a new random solution vector
     */
    public int[] randomResult() {
        int[] r = new int[MEMBER_NUM * 2];
        for (int j = 0; j < r.length; j++) {
            r[j] = random.nextInt(MAX_FLIGHT_INDEX + 1);
        }
        return r;
    }

    public static void main(String[] args) {
        Optimization o = new Optimization();
        o.initPeople();
        // Alternatives: o.randomOptimize(100); o.hillclimb(); o.annealingoptimize(10000.0, 0.95, 2);
        int[] best = o.geneticoptimize(50, 2, 0.2, 0.2, 100);
        o.printSchedule(best);
    }
}
/**
*
* FILENAME: Flights.java
* AUTHOR: vivizhyy[at]gmail.com
* STUID: whu
* DATE: 2010-4-12
* USAGE :
*/
package ch5.optimization;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.joda.time.LocalTime;
/**
 * One scheduled flight (origin, destination, departure and arrival time, price). Also provides
 * static helpers that load the schedule from {@link #SCHEDULE}, a comma-separated file with one
 * {@code origin,dest,H:MM,H:MM,price} record per line.
 */
public class Flights {
    private String origin;
    private String dest;
    private LocalTime depart;
    private LocalTime arraive;
    private float price;
    /** Path of the schedule file read by {@link #getFlights()}. */
    public static String SCHEDULE = "schedule.txt";
    public Logger log = Logger.getLogger(Flights.class.getName());
    /** Maximum number of schedule records loaded by {@link #getFlights()}. */
    private static final int MAX_FLIGHTS = 120;
    /** Capacity of the array returned by {@link #getFlightByOriginAndDest}. */
    private static final int MAX_FLIGHTS_PER_ROUTE = 15;

    public Flights() {
    }

    public LocalTime getArraive() {
        return arraive;
    }

    public LocalTime getDepart() {
        return depart;
    }

    public String getDest() {
        return dest;
    }

    public Logger getLog() {
        return log;
    }

    public String getOrigin() {
        return origin;
    }

    public float getPrice() {
        return price;
    }

    public void setArraive(LocalTime arraive) {
        this.arraive = arraive;
    }

    public void setDepart(LocalTime depart) {
        this.depart = depart;
    }

    public void setDest(String dest) {
        this.dest = dest;
    }

    public void setLog(Logger log) {
        this.log = log;
    }

    public void setOrigin(String origin) {
        this.origin = origin;
    }

    public void setPrice(float price) {
        this.price = price;
    }

    /**
     * Reads up to {@code MAX_FLIGHTS} records from {@link #SCHEDULE}.
     *
     * <p>The returned array always has {@code MAX_FLIGHTS} entries. Slots beyond the number of
     * records read are empty {@code Flights} objects with {@code null} fields. Blank lines are
     * skipped. Lines that do not have exactly 5 fields are reported on stderr and ignored. I/O
     * errors are printed and whatever was read so far is returned.
     *
     * @return the loaded flights (never null)
     * @throws NumberFormatException if a time or price field is not a number
     */
    public static Flights[] getFlights() {
        Flights[] flights = new Flights[MAX_FLIGHTS];
        for (int i = 0; i < MAX_FLIGHTS; i++) {
            flights[i] = new Flights();
        }
        PropertyConfigurator.configure("D:/Documents/NetBeansProjects/CollectiveProgramming/src/log4j.properties");
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(SCHEDULE));
            int count = 0;
            String line;
            while (count < MAX_FLIGHTS && (line = reader.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue; // blank lines are not malformed records
                }
                String[] result = line.split(",");
                if (result.length == 5) {
                    Flights f = flights[count];
                    f.origin = result[0].trim();
                    f.dest = result[1].trim();
                    f.depart = getTimeByValue(result[2]);
                    f.arraive = getTimeByValue(result[3]);
                    f.price = Integer.parseInt(result[4].trim());
                    count++;
                } else {
                    System.err.println("schedule format wrong.");
                }
            }
        } catch (IOException ex) {
            ex.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
        }
        return flights;
    }

    /**
     * Parses a time written as {@code H:MM} (surrounding whitespace allowed).
     *
     * @param time text such as {@code "6:19"}
     * @return the corresponding time of day
     * @throws IllegalArgumentException if {@code time} contains no ':'
     * @throws NumberFormatException if hour or minute is not an integer
     */
    public static LocalTime getTimeByValue(String time) {
        String trimmed = time.trim();
        int mark = trimmed.indexOf(':');
        if (mark < 0) {
            throw new IllegalArgumentException("time must look like H:MM, got \"" + time + "\"");
        }
        int hour = Integer.parseInt(trimmed.substring(0, mark));
        int minute = Integer.parseInt(trimmed.substring(mark + 1));
        return new LocalTime(hour, minute);
    }

    /**
     * Returns the flights from {@code strOrigin} to {@code strDest}, in schedule-file order.
     *
     * <p>The schedule file is re-read on every call. The result has {@code MAX_FLIGHTS_PER_ROUTE}
     * slots; unused slots are {@code null}. If more matching flights exist, an error is logged and
     * the extra ones are dropped.
     *
     * @param strOrigin origin airport code
     * @param strDest destination airport code
     * @return matching flights, padded with nulls
     */
    public Flights[] getFlightByOriginAndDest(String strOrigin, String strDest) {
        Flights[] resultSet = new Flights[MAX_FLIGHTS_PER_ROUTE];
        int count = 0;
        for (Flights f : getFlights()) {
            // Unfilled slots (fewer records than MAX_FLIGHTS) have null fields.
            if (f.origin == null || !f.origin.equals(strOrigin) || !f.dest.equals(strDest)) {
                continue;
            }
            // Only complain when a match actually does not fit.
            if (count >= resultSet.length) {
                log.error("flight time lager than define.");
                break;
            }
            resultSet[count++] = f;
        }
        return resultSet;
    }

    @Override
    public String toString() {
        return this.origin + "\t"
                + this.dest + "\t"
                + this.depart + "-" + this.arraive + "\t"
                + "$" + this.price;
    }
}
/**
*
* FILENAME: Score.java
* AUTHOR: vivizhyy@gmail.com
* STUID: whu200732580127
* DATE: 2010-4-13
* USAGE :
*/
package ch5.optimization;
import java.util.Comparator;
public class Score {
public int[] list = new int[12];
public double price;
}
class CompareScore implements Comparator{
public int compare(Object o1, Object o2) {
Score s1 = (Score)o1;
Score s2 = (Score)o2;
int flag = 0;
if(s1.price > s2.price)
flag = 1;
return flag;
}
}