题意:
求第K短路。注意:S == T时,0不算一条路径。
思路:
(A*算法入门)
众所周知,A*算法就是启发式搜索,基本形式就是这样:f(x) = g(x)+h(x);其中f(x)代表在x点所需要的总代价,而g(x)代表从源点到x点已经耗费的实际代价,h(x)代表从x到终点需要的估计代价,这个函数是一个估计值。而从x到终点真正需要的代价为h*(x),在整个启发式搜索中我们必须保证h(x) <= h*(x);不然的话会由于对当前的估价值过高,则会引起答案的错误。构建A*的关键在于准确的规划一个h(x)函数,使得接近h*(x),这样的搜索会使得答案又快又准。可以想象h(x)过小会使得解的空间过大,这样搜索出来的结果会很准确但是速度太慢,而对h(x)的过高估计,即估计代价太大又会使得结果不准确。
这样我们便可以理解了BFS的搜索过程,BFS的搜索过程中没有考虑到h(x)的估计代价,也就是说h(x) = 0,只考虑g(x)的实际代价。这样根据实际代价来进行搜索,虽然BFS可以说是一个很恶心的A*,同样地我们可以知道,BFS的解空间确实很大。
第一次写A*,目前只会应用在K短路上。不过也有点感觉了,关键在于h(x)的设计!
谈具体的实现:
首先我们在解空间取出的就是f(x)最小的,这样我们就要运用到优先队列了。这里提供一个使用C++系统优先队列的方法:C++的STL中自带了优先队列,通过重载运算法"<",可以实现我们需要的对f(x)的自动维护。
描述一下怎样用启发式搜索来解决K短路:
首先我们知道A*的基础公式:f(x)=g(x)+h(x);对h(x)进行设计,根据定义h(x)为当前的x点到目标点t所需要的估计距离。也就是说x->t距离,由于有很多的节点都是到t的距离,为了计算这个估计值,当然必须先算出x->t的最短路径长度。显然x的表示会很多而t的值只有一个,对每个x去求单源点最短路径当然不划算!于是反过来做,从t点出发到其他点的单源点最短路径,这样把估价函数h(x)就都求出来,注意这样求出来的h(x) = h*(x)。
然后就可以对构造完的h(x)开始启发式搜索了:
首先的点当然就是定义头结点了,头结点的已消耗代价为0,估计代价为h[s],下一个点为v;进入队列,开始while循环。每次取出队头的f(x)最小的节点对其他节点进行拓展。对当前节点的拓展次数++,若当前节点的拓展次数超过K,显然不符合要求,则不进行拓展。若对t节点的拓展次数恰好为K,则找到了所需要的K短路。对当前节点的拓展次数即为到当前节点的第几短路。找到需要节点的K短路后,返回g(t)即可,也就是通过K次拓展的实际消耗的长度。
在while循环中的入队情况:当前节点的所有可拓展边的所有状态都入队,当前节点到被拓展节点的实际代价为当前节点的实际代价+两节点之间的边长。下个节点就是被拓展节点,估计函数的值则为被拓展节点到目标节点的距离h(x)。
(分析来自:http://blog.csdn.net/mbxc816/article/details/7197228)
时间复杂度网上说是:O((n+m)logn + mlogm + klogk)。
Code:
#include <string.h>
#include <cstdio>
#include <queue>
using namespace std;
const int maxn = 1005;
const int maxm = 100005;
struct node1
{
int v, w, next;
} edge[2][maxm];
struct node2
{
int v, g, h;
bool operator <(const node2 k) const
{
return g+h > k.g+k.h;
}
} e;
int head[2][maxn], vis[maxn], h[maxn], cnt[maxn];
int N, M, S, T, K, no1, no2;
priority_queue<node2> q;
queue<int> q1;
inline void init()
{
no1 = no2 = 0;
memset(head, -1, sizeof head);
memset(cnt, 0, sizeof cnt);
memset(h, 0x3f, sizeof h);
while(!q1.empty()) q1.pop();
while(!q.empty()) q.pop();
}
inline void add(int u, int v, int w) //第二维用于求逆向最短路
{
edge[0][no1].v = v, edge[0][no1].w = w;
edge[0][no1].next = head[0][u]; head[0][u] = no1++;
edge[1][no2].v = u, edge[1][no2].w = w;
edge[1][no2].next = head[1][v]; head[1][v] = no2++;
}
void SPFA() //逆向最短路
{
h[T] = 0; vis[T] = 1;
q1.push(T);
while(!q1.empty())
{
int tp = q1.front(); q1.pop();
vis[tp] = 0;
int k = head[1][tp];
while(k != -1)
{
if(h[edge[1][k].v] > h[tp]+edge[1][k].w)
{
h[edge[1][k].v] = h[tp]+edge[1][k].w;
if(!vis[edge[1][k].v])
{
vis[edge[1][k].v] = 1;
q1.push(edge[1][k].v);
}
}
k = edge[1][k].next;
}
}
}
int AStar_Kth()
{
e.v = S, e.g = 0, e.h = h[S];
q.push(e);
while(!q.empty())
{
node2 tp = q.top(); q.pop();
++cnt[tp.v];
if(cnt[tp.v] > K) continue;
if(cnt[T] == K) return tp.g;
int k = head[0][tp.v];
while(k != -1)
{
e.v = edge[0][k].v;
e.g = tp.g + edge[0][k].w;
e.h = h[edge[0][k].v];
//注意:过程中的e.g可能会变小,但e.g+e.h是一直不下降的
q.push(e);
k = edge[0][k].next;
}
}
return -1;
}
int main()
{
int u, v, w;
scanf("%d %d", &N, &M); init();
for(int i = 1; i <= M; ++i)
{
scanf("%d %d %d", &u, &v, &w);
add(u, v, w);
}
scanf("%d %d %d", &S, &T, &K);
if(S == T) ++K; //S==T时,路径为0不算一条路径
SPFA();
printf("%d\n", AStar_Kth());
return 0;
}
继续加油~