给你N个点和M条边的有向图,其中第N个点是源点
让你求每个节点#I关于点#N的关键点的编号和
解题思路:
这题是2013 Multi-University Training Contest 9的题目,官方题解是用bitset或者floyd乱搞
然而丁神告诉我们这题是一题裸的Lengauer_Tarjan算法,具体算法见Tarjan论文,有详细的伪代码
附上丁神的模板
Lengauer-Tarjan algorithm的作用是求出dominator tree,这棵树上的每一个节点都是其儿子的idom
明显我们求出这棵树之后,原题要求的就是每个节点到根的路径上的各节点的编号和
O(MlogM)的算法和O(Mα(M,N))的算法(α为反阿克曼函数)的区别在于EVAL和LINK函数的实现
succ数组存的是原图
fa数组存的是i结点的先驱,在dfs生成树上的父亲
dfn数组存的是i结点的新编号,redfn存的是i结点的原编号
prod数组存的是按dfs序重新编号之后的反向图,即prod[i]里是新编号为i的结点的所有前驱(均为新编号)
semi数组存半必经点的新编号,表示的是在dfs树上,节点i的祖先中,可以通过一系列的非树边走到i的,深度最小的祖先,i的直系父亲也可以是半必经点
bucket数组存的是,以i作为半必经点的点
idom数组存的是,i的immediate dominator
anc数组:step3里把点都先当成孤立的森林,然后每访问一个点,就将他和他父亲连边,anc数组存的就是结点在这个森林里的父亲,并且compress用了类似并查集的方式压缩路径,以此更快速地找到路径上semi最小的祖先
<pre name="code" class="cpp">#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <algorithm>
#include <vector>
#include <queue>
#include <deque>
#include <utility>
#include <map>
#include <set>
#include <cctype>
#include <climits>
#include <stack>
#include <cmath>
#include <bitset>
#include <numeric>
#include <functional>
using namespace std;
// Graph containers. succ is filled by MAIN and indexed by original vertex id
// (succ[u] holds the heads of the edges leaving u). prod, bucket and dom_t are
// indexed by DFS timestamp: prod[t] collects the timestamps of t's
// predecessors, bucket[s] collects the vertices whose semi value is s, and
// dom_t[d] collects the vertices whose idom is d.
vector<int> succ[50010], prod[50010], bucket[50010], dom_t[50010];
// Working arrays, all indexed by DFS timestamp. dfs() initialises
// semi[t] = best[t] = t and anc[t] = idom[t] = 0; fa[t] is the timestamp of
// t's parent in the DFS tree (-1 until assigned).
int semi[50010], anc[50010], idom[50010], best[50010], fa[50010];
// dfn[v] is the timestamp given to original vertex v (-1 while unvisited);
// redfn[t] is the original vertex that received timestamp t.
int dfn[50010], redfn[50010];
// Number of vertices numbered so far; reset to 0 by lengauer_tarjan().
int timestamp;
void dfs(int now)
{
dfn[now] = ++timestamp;
redfn[timestamp] = now;
anc[timestamp] = idom[timestamp] = 0;
semi[timestamp] = best[timestamp] = timestamp;
int sz = succ[now].size();
for(int i = 0; i < sz; ++i)
{
if(dfn[succ[now][i]] == -1)
{
dfs(succ[now][i]);
fa[dfn[succ[now][i]]] = dfn[now];
}
prod[dfn[succ[now][i]]].push_back(dfn[now]);
}
}
// Path compression on the anc forest (timestamps).
// If the node above `now` is not itself directly below a root, the node above
// is compressed first. Then `now` takes over that node's best entry when it
// has the smaller semi value, and `now` is re-attached one level higher.
void compress(int now)
{
    const int up = anc[now];
    if(anc[up] == 0)
        return;                 // `up` has no node above it: nothing to shorten

    compress(up);
    if(semi[best[up]] < semi[best[now]])
        best[now] = best[up];
    anc[now] = anc[up];
}
// Returns `now` itself when it has no anc link yet; otherwise compresses the
// path above `now` and returns best[now].
int eval(int now)
{
    if(anc[now] != 0)
    {
        compress(now);
        return best[now];
    }
    return now;
}
// Diagnostic dump to stdout: for every timestamp from `timestamp` down to 2,
// prints the original ids of the vertex, of its anc entry and of its fa entry,
// followed by a separator line.
void debug()
{
    int t = timestamp;
    while(t > 1)
    {
        cout << redfn[t] << " " << redfn[anc[t]] << " " << redfn[fa[t]] << endl;
        --t;
    }
    cout << "---------------------" << endl;
}
// Builds the dominator tree of the graph stored in succ, rooted at original
// vertex n (simple-link variant of Lengauer-Tarjan).
//
// On return, for every DFS timestamp t in 2..timestamp, idom[t] is the
// timestamp of t's immediate dominator and dom_t[idom[t]] contains t;
// idom[1] is 0 for the root. Use redfn[] to map timestamps back to original
// vertex ids. prod, bucket and dom_t must be empty on entry (MAIN clears them).
//
// This function writes nothing to stdout; the debug() dump that used to run
// on every iteration corrupted the program's answer output.
void lengauer_tarjan(int n)
{
    memset(dfn, -1, sizeof dfn);
    memset(fa, -1, sizeof fa);
    timestamp = 0;
    dfs(n);
    fa[1] = 0;                      // the root has no DFS-tree parent

    // Process vertices in decreasing timestamp order. Every timestamp in
    // 2..timestamp was assigned a parent by dfs(), so fa[w] >= 1 here.
    for(int w = timestamp; w > 1; --w)
    {
        const int parent = fa[w];

        // Semidominator of w: minimum semi over eval() of all predecessors.
        const int preds = prod[w].size();
        for(int i = 0; i < preds; ++i)
        {
            const int u = eval(prod[w][i]);
            if(semi[u] < semi[w])
                semi[w] = semi[u];
        }
        bucket[semi[w]].push_back(w);

        // Simple link: attach w below its DFS-tree parent in the anc forest.
        anc[w] = parent;

        // Resolve every vertex whose semidominator is `parent`.
        const int waiting = bucket[parent].size();
        for(int i = 0; i < waiting; ++i)
        {
            const int v = bucket[parent][i];
            const int u = eval(v);
            idom[v] = (semi[u] < parent) ? u : parent;
        }
        bucket[parent].clear();
    }

    // Vertices whose idom was deferred take it from the vertex recorded above.
    for(int w = 2; w <= timestamp; ++w)
    {
        if(idom[w] != semi[w])
            idom[w] = idom[idom[w]];
    }
    idom[1] = 0;

    for(int i = timestamp; i > 1; --i)
        dom_t[idom[i]].push_back(i);
}
long long ans[50010];
void get_ans(int now)
{
ans[redfn[now]] += redfn[now];
int sz = dom_t[now].size();
for(int i = 0; i < sz; ++i)
{
ans[redfn[dom_t[now][i]]] += ans[redfn[now]];
get_ans(dom_t[now][i]);
}
}
// Handles one test case: clears the per-case containers for indices 0..n,
// reads m directed edges "u v" from stdin into succ, builds the dominator
// tree rooted at vertex n and prints ans[1..n] on one line.
void MAIN(int n, int m)
{
    int idx = 0;
    while(idx <= n)
    {
        succ[idx].clear();
        prod[idx].clear();
        bucket[idx].clear();
        dom_t[idx].clear();
        ++idx;
    }

    int u, v;
    for(int remaining = m; remaining > 0; --remaining)
    {
        scanf("%d%d", &u, &v);
        succ[u].push_back(v);
    }

    lengauer_tarjan(n);
    memset(ans, 0, sizeof ans);
    get_ans(1);                     // timestamp 1 is the root

    for(int vertex = 1; vertex <= n; ++vertex)
    {
        const char tail = (vertex == n) ? '\n' : ' ';
        printf("%I64d%c", ans[vertex], tail);
    }
}
// Entry point: runs MAIN for every "n m" header found on stdin.
int main()
{
    int n, m;
    // Require both values: a partial read (scanf returning 1) would otherwise
    // start a test case with an uninitialised m.
    while(scanf("%d%d", &n, &m) == 2)
        MAIN(n, m);
    return 0;
}
O(Mα(M,N))版本的模板,来源丁神
</pre><pre name="code" class="cpp">// whn6325689
// Mr.Phoebe
// http://blog.csdn.net/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
// Short aliases for commonly used types.
typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;       // pair of two long long values
typedef complex<ld> point;      // complex number over long double
typedef pair<int, int> pii;     // pair of two ints
typedef pair<pii, int> piii;    // (pair of ints, int)
typedef vector<int> vi;
// Convenience macros. Arguments and expansions are fully parenthesised so
// they keep their meaning inside larger expressions, e.g. `LLINF - 1` or
// `lowbit(a | b)`.
#define CLR(x,y) memset((x),(y),sizeof(x))  // fill a whole array with byte y
#define mp(x,y) make_pair((x),(y))
#define pb(x) push_back(x)
#define lowbit(x) ((x)&(-(x)))               // lowest set bit of x
#define MID(x,y) ((x)+(((y)-(x))>>1))        // midpoint without computing x+y
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF (1LL<<62)
// Reads the next decimal integer from stdin into n.
// Skips every character that is neither a digit nor '-', then parses an
// optional '-' followed by a run of digits. The character that ends the digit
// run is consumed as well.
// Returns false, leaving n untouched, if end of input is reached before a
// number starts; returns true otherwise.
template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    // getchar() returns int so that EOF is distinct from every byte value.
    // Storing it in a char confuses byte 0xFF with EOF where char is signed
    // and never matches EOF where char is unsigned.
    int c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'), c = getchar();
    n = x*tmp;
    return true;
}
// Writes the integer n to stdout in decimal, with a leading '-' for negative
// values and a single '0' for zero. No separator or newline is appended.
template <class T>
inline void write(T n)
{
    const bool negative = n < 0;
    if(negative)
        putchar('-');

    // Digits are taken from n as it is, without negating it first: -n
    // overflows for the most negative value of a signed type.
    // 3 decimal digits per byte is a safe upper bound for the buffer.
    int len = 0, data[3 * sizeof(T) + 1];
    do
    {
        const int digit = static_cast<int>(n % 10);   // in -9..0 when n < 0
        data[len++] = negative ? -digit : digit;
        n /= 10;
    }
    while(n != 0);

    while(len--) putchar(data[len]+48);
}
//-----------------------------------
// Author: winoros ding
// Dominator tree by the Lengauer-Tarjan algorithm with the balanced
// LINK/EVAL of the original paper (bounded there by O(m * alpha(m, n)), where
// m is the number of edges and n the number of vertices).
//
// input : succ  - succ[u] lists the heads of the edges leaving original vertex u
// output: idom, dom_t and redfn
// Notice that the index i in idom[i] / dom_t[i] and the entries dom_t[i][j]
// are DFS timestamps, hence you need redfn[i] to find the original vertex.
// prod, bucket and dom_t must be empty before lengauer_tarjan() is called.
//
// Index 0 is the "null" vertex: semi[0], best[0], child[0], subsize[0] and
// anc[0] are never written and must stay 0 for link() and eval() to work.
const int vector_num=50000; //max number of vertices
std::vector<int> succ[vector_num+10], prod[vector_num+10], bucket[vector_num+10], dom_t[vector_num+10];
int semi[vector_num+10], anc[vector_num+10], idom[vector_num+10], best[vector_num+10], fa[vector_num+10];
int dfn[vector_num+10], redfn[vector_num+10];
// child[] chains the subtree roots of one tree of the link-eval forest;
// subsize[] is the size bookkeeping used for balancing. (Named subsize rather
// than size: with `using namespace std`, an unqualified `size` is ambiguous
// with std::size from C++17 on.)
int child[vector_num+10], subsize[vector_num+10];
int timestamp;

// Depth-first numbering from original vertex `now`: assigns the next
// timestamp, initialises that timestamp's working slots, records reversed
// edges in prod (timestamp numbering) and the DFS-tree parent in fa.
void dfs(int now)
{
    // timestamp keeps growing during the recursion, so remember our own value.
    const int id = ++timestamp;
    dfn[now] = id;
    redfn[id] = now;
    anc[id] = idom[id] = child[id] = 0;
    subsize[id] = 1;                // every vertex starts as a tree of size 1
    semi[id] = best[id] = id;

    const int degree = succ[now].size();
    for(int i = 0; i < degree; ++i)
    {
        const int to = succ[now][i];
        if(dfn[to] == -1)
        {
            dfs(to);
            fa[dfn[to]] = id;
        }
        prod[dfn[to]].push_back(id);
    }
}

// Path compression above `now`: afterwards anc[now] is the topmost linked
// node above it and best[now] holds the minimum-semi label seen on the way.
void compress(int now)
{
    if(anc[anc[now]] != 0)
    {
        compress(anc[now]);
        if(semi[best[now]] > semi[best[anc[now]]])
            best[now] = best[anc[now]];
        anc[now] = anc[anc[now]];
    }
}

// Returns the label with minimum semi for `now`.
// A node without an anc link answers with its own label best[now]; with
// balanced linking such a node can be a subtree root whose label differs
// from the node itself, so returning `now` would be wrong.
inline int eval(int now)
{
    if(anc[now] == 0)
        return best[now];
    compress(now);
    return semi[best[anc[now]]] >= semi[best[now]] ? best[now]
                                                   : best[anc[now]];
}

// Balanced link: makes the tree rooted at w part of the tree rooted at v.
inline void link(int v, int w)
{
    int s = w;
    // Rebalance w's chain of subtree roots while w's label beats the label of
    // the next root in the chain. The test uses child[s], not child[w]: s
    // moves down the chain in the else branch.
    // semi[best[0]] == 0 ends the loop when the chain is exhausted.
    while(semi[best[w]] < semi[best[child[s]]])
    {
        const int next = child[s];
        if(subsize[s] + subsize[child[next]] >= 2*subsize[next])
        {
            anc[next] = s;
            child[s] = child[next];
        }
        else
        {
            subsize[next] = subsize[s];
            anc[s] = next;
            s = next;
        }
    }
    best[s] = best[w];
    subsize[v] += subsize[w];
    // Keep the larger side as the chain of subtree roots hanging off v.
    if(subsize[v] < 2*subsize[w])
        std::swap(s, child[v]);
    while(s != 0)
    {
        anc[s] = v;
        s = child[s];
    }
}

// Computes immediate dominators for every vertex reachable from original
// vertex n, which is used as the root.
// On return, for each timestamp t in 2..timestamp, idom[t] is the timestamp
// of t's immediate dominator and dom_t[idom[t]] contains t; idom[1] is 0.
// Vertices not reachable from n keep dfn == -1.
void lengauer_tarjan(int n) // n is the vertices' number
{
    std::memset(dfn, -1, sizeof dfn);
    std::memset(fa, -1, sizeof fa);
    timestamp = 0;
    dfs(n);
    fa[1] = 0;                      // the root has no DFS-tree parent

    // Every timestamp in 2..timestamp got a parent from dfs(), so fa[w] >= 1.
    for(int w = timestamp; w > 1; --w)
    {
        const int parent = fa[w];

        // Semidominator of w: minimum semi over eval() of all predecessors.
        const int preds = prod[w].size();
        for(int i = 0; i < preds; ++i)
        {
            const int u = eval(prod[w][i]);
            if(semi[w] > semi[u])
                semi[w] = semi[u];
        }
        bucket[semi[w]].push_back(w);

        // The O(m log m) version would simply do anc[w] = parent here.
        link(parent, w);

        // Resolve every vertex whose semidominator is `parent`.
        const int waiting = bucket[parent].size();
        for(int i = 0; i < waiting; ++i)
        {
            const int v = bucket[parent][i];
            const int u = eval(v);
            idom[v] = (semi[u] < semi[v]) ? u : parent;
        }
        bucket[parent].clear();
    }

    // Vertices whose idom was deferred take it from the vertex recorded above.
    for(int w = 2; w <= timestamp; ++w)
    {
        if(idom[w] != semi[w])
            idom[w] = idom[idom[w]];
    }
    idom[1] = 0;

    for(int i = timestamp; i > 1; --i)
        dom_t[idom[i]].push_back(i);
}
// ans[v], indexed by original vertex id: sum of the original ids on the
// dominator-tree path from the root down to v, filled in by get_ans().
long long ans[50010];

// Pre-order walk of the dominator tree starting at timestamp `now`. A vertex
// first adds its own id to its running total, then seeds every child with
// that total and recurses into the child.
void get_ans(int now)
{
    long long &total = ans[redfn[now]];
    total += redfn[now];

    for(int idx = 0, count = dom_t[now].size(); idx < count; ++idx)
    {
        const int next = dom_t[now][idx];
        ans[redfn[next]] += total;
        get_ans(next);
    }
}
// Resets the per-test-case state: empties the four vector arrays for indices
// 0..n and zeroes the whole ans array. The parameter m is not used.
void init(int n, int m)
{
    int idx = 0;
    while(idx <= n)
    {
        succ[idx].clear();
        prod[idx].clear();
        bucket[idx].clear();
        dom_t[idx].clear();
        ++idx;
    }
    memset(ans, 0, sizeof(ans));
}
// Entry point. For every "n m" header on stdin: resets state, reads m edges
// "u v" into succ, builds the dominator tree rooted at vertex n and prints
// ans[1..n] separated by spaces on one line.
int main()
{
    int n, m;
    for(;;)
    {
        if(!read(n))
            break;
        if(!read(m))
            break;

        init(n,m);
        int u, v;
        for(int edge = 0; edge < m; ++edge)
        {
            read(u), read(v);
            succ[u].push_back(v);
        }

        lengauer_tarjan(n);
        get_ans(1);                 // timestamp 1 is the root

        for(int vertex = 1; vertex <= n; ++vertex)
        {
            const char tail = (vertex == n) ? '\n' : ' ';
            printf("%I64d%c", ans[vertex], tail);
        }
    }
    return 0;
}