参考了以下博客的思路,实现了java版本,原博客的实现思路是对的,但是求最小值的方法错误,在此基础上修改了,最大值的贪婪算法没去实现
分支限界TSP算法c版
下图是要求的TSP图
算法如下
import java.util.*;
class Node implements Comparable{
int[] visp;//标记哪些点走了
int st;//起点
int ed;//终点
int k;//走过的点数
int sumv;//经过路径的距离
int lb;//目标函数的值
Map<Integer,Integer> map_edge=new HashMap<>();//记录已经加入的边
@Override
public int compareTo(Object o){
Node node=(Node) o;
if(node.lb<this.lb)
return 1;
else if(node.lb>this.lb)
return -1;
return 0;
}
}
public class BBTSP {
private int[][] mp;
int n;
int up=16;//路径总和上界,第一次为无穷大,后面取每个可行分支的最小值
int low;//路径和最小值
private List<Point> points;
private PriorityQueue<Node> q=new PriorityQueue<>();
private PriorityQueue<Node> q_last=new PriorityQueue<>();//记录每条路径的最后一个节点,以及对应的路径值
public BBTSP(int[][] mp,List<Point> points){
this.mp=mp;
points=this.points;
n=points.size();
sum_origin=points.get(0).getStock();
}
public BBTSP(int[][] mp){
this.mp=mp;
n=5;
}
void get_low()
{
low=0;
for(int i=0; i<n; i++)
{
/*通过排序求两个最小值*/
//double min1=Double.MAX_VALUE,min2=Double.MAX_VALUE;
double[] tmpA=new double[n];
for(int j=0; j<n; j++)
{
//if(i==j) continue;
tmpA[j]=mp[i][j];
}
Arrays.sort(tmpA);//对临时的数组进行排序
low+=tmpA[1]+tmpA[2];
}
low=low%2==0?low/2:(low/2+1);
}
public int get_lb(Node p){
int ret=p.sumv*2;//路径上的点的距离
double min1=Double.MAX_VALUE,min2=Double.MAX_VALUE;//起点和终点连出来的边
/* System.out.println("边:");
for(Map.Entry<Integer,Integer> entry: p.map_edge.entrySet())
System.out.println("start="+entry.getKey()+",end="+entry.getValue());*/
Map<Integer,Integer> map=p.map_edge;
for(int i=0; i<n; i++) {
// System.out.println("++++++++++++i="+i+"+++++++++++++++++++");
boolean flag1 = false;//该点为出点
boolean flag2 = false;//该点为入点
int end = -1;
int start = -1;
if (map.containsKey(i)) {
flag1 = true;
end = map.get(i);
}
if (map.containsValue(i)) {
flag2 = true;
for(Map.Entry<Integer,Integer> entry:map.entrySet())
if(entry.getValue()==i) start=entry.getKey();
}
if (flag1 && flag2) continue;
List<Integer> array=new ArrayList<>();
if (!flag1&&flag2) {//该点只有入点,没有出点
for (int j = 0; j < n; j++) {
if (i == j || j == start) continue;
array.add(mp[i][j]);
//System.out.println("flag1 map["+i+"]"+"["+j+"]="+mp[i][j]);
}
Collections.sort(array);
ret += array.get(0);
// System.out.println("array.get(0)="+array.get(0)+",ret="+ret);
}
if (!flag2&&flag1) {
array=new ArrayList<>();
for (int j = 0; j < n; j++) {
if (i == j || j == end) continue;
array.add(mp[j][i]);
//System.out.println("flag2 map["+j+"]"+"["+i+"]="+mp[j][i]);
}
Collections.sort(array);
ret += array.get(0);
//System.out.println("array.get(0)="+array.get(0)+",ret="+ret);
}
if(!flag1&&!flag2){
array=new ArrayList<>();
for (int j = 0; j < n; j++) {
if (i == j) continue;
array.add(mp[i][j]);
}
Collections.sort(array);
ret += array.get(0)+array.get(1);
// System.out.println("array.get(0)="+array.get(0)+"array.get(1)="+array.get(1)+",ret="+ret);
}
}
System.out.println("2.ret="+ret);
return ret%2==0?(ret/2):(ret/2+1);
}
public Node solve(){
get_low();
System.out.println("low="+low);
/*设置初始点,默认从1开始 */
Node star=new Node();
star.st=0;
star.ed=0;
star.k=1;
star.visp=new int[n];
for(int i=0; i<n; i++) star.visp[i]=0;
star.visp[0]=1;
star.sumv=0;
star.lb=low;
System.out.println("n="+n);
/*ret为问题的解*/
double ret=Double.MAX_VALUE;
q.add(star);
while(!q.isEmpty())
{
Node tmp=q.peek();
if(!q_last.isEmpty()){
Node last=q_last.peek();
if(last.lb<=tmp.lb) return last;
}
System.out.println("输出队列里面的数据");
Iterator<Node> it=q.iterator();
while (it.hasNext()){
Node no=it.next();
System.out.println("node.st="+no.st+";node.ed="+no.ed+";node.sumv="+no.sumv);
}
System.out.println("--------------------------------------------------------");
System.out.println("tmp.st="+tmp.st+";tmp.ed="+tmp.ed+";tmp.sumv="+tmp.sumv);
Map<Integer,Integer> tmp_map=tmp.map_edge;
q.poll();
if(tmp.k==n-1)
{
System.out.println("最后一个点");
/*找最后一个没有走的点*/
int p=0;
for(int i=0; i<n; i++)
{
if(tmp.visp[i]==0)
{
p=i;
break;
}
}
Node next=new Node();
next.visp=new int[n];
next.st=tmp.ed;
next.ed=p;
int ans=tmp.sumv+mp[p][0]+mp[tmp.ed][p];//最终的最短路径
next.sumv=ans;
next.k=tmp.k+1;
next.map_edge.putAll(tmp.map_edge);
next.map_edge.put(next.st,next.ed);
next.map_edge.put(next.ed,0);
next.lb=ans;
System.out.println("next.i="+p+";next.lib="+next.lb+";next.st="+next.st+";next.ed="+next.ed+";next.sumv="+next.sumv);
Node judge = q.peek();
/*如果当前的路径和比所有的目标函数值都小则跳出*/
if(ans <= judge.lb||judge==null)
{
//ret=Math.min(ans,ret);
// ret_map.put(ans,next);
return next;
// break;
}
/*否则继续求其他可能的路径和,并更新上界*/
else
{
up = Math.min(up,ans);
q_last.add(next);
// ret=Math.min(ret,ans);
continue;
}
}
/*当前点可以向下扩展的点入优先级队列*/
for(int i=0; i<n; i++)
{
if(tmp.visp[i]==0)
{
Node next=new Node();
next.visp=new int[n];
next.st=tmp.ed;
/*更新路径和*/
//System.out.println("tmp.sumv="+tmp.sumv);
next.sumv=tmp.sumv+mp[tmp.ed][i];
/*更新最后一个点*/
next.ed=i;
/*更新顶点数*/
next.k=tmp.k+1;
/*更新经过的顶点*/
for(int j=0; j<n; j++) next.visp[j]=tmp.visp[j];
next.visp[i]=1;
/*求目标函数*/
Map<Integer,Integer> next_map=new HashMap<>();
next_map.putAll(tmp_map);
next_map.put(next.st,next.ed);
next.map_edge=next_map;
next.lb=get_lb(next);
System.out.println("sumv="+next.sumv);
System.out.println("next.i="+i+";next.lib="+next.lb+";next.st="+next.st+";next.ed="+next.ed+";next.sumv="+next.sumv);
/*如果大于上界就不加入队列*/
if(next.lb>up){
next_map.remove(next.st);
continue;
}
q.add(next);
}
}
}
// return ret;
return null;
}
public static void main(String[] args){
int[][] d={
{0,3,1,5,8},
{3,0,6,7,9},
{1,6,0,4,2},
{5,7,4,0,3},
{8,9,2,3,0}
};
BBTSP b=new BBTSP(d);
Node node=b.solve();
System.out.println();
System.out.println("+++++++++++++++++++++++++输出结果:++++++++++++++++++++++++");
System.out.println("最后遍历的点的信息:");
System.out.println("node.lib="+node.lb+";node.st="+node.st+";node.ed="+node.ed+";node.sumv="+node.sumv);
System.out.println("最短路径为:"+node.lb);
System.out.println("构成的边为:");
for(Map.Entry<Integer,Integer> entry: node.map_edge.entrySet()){
System.out.println(entry.getKey()+" -> "+entry.getValue());
}
/* Node n1=new Node();
n1.lb=9;
Node n2=new Node();
n2.lb=19;
Node n3=new Node();
n3.lb=11;
Node n4=new Node();
n4.lb=5;
Node n5=new Node();
n5.lb=2;
PriorityQueue<Node> q1=new PriorityQueue<>();
q1.add(n1);
q1.add(n2);
q1.add(n3);
q1.add(n4);
q1.add(n5);
while(!q1.isEmpty()){
Node nn=q1.poll();
System.out.println(nn.lb);
}*/
}
}
最后的结果是:
最后遍历的点的信息:
node.lib=16;node.st=3;node.ed=1;node.sumv=16
最短路径为:16
构成的边为:
0 -> 2
1 -> 0
2 -> 4
3 -> 1
4 -> 3