1787: [Ahoi2008]Meet 紧急集合
Time Limit: 20 Sec Memory Limit:162 MBSubmit: 590 Solved: 240
[Submit][Status][Discuss]
Description
Input
Output
Sample Input
1 2
2 3
2 4
4 5
5 6
4 5 6
6 3 1
2 4 4
6 6 6
Sample Output
5 2
2 5
4 1
6 0
HINT
Source
解析:裸的LCA(即最近公共祖先)。设a、b、c三点,则集合点必为lca(a,b)、lca(a,c)、lca(b,c)中的一个,然后求解个点到集合点的距离就可以了。
1.如何求解树上的两个点a、b的距离:
用 h[i] 记录节点i 在树中的深度,若 j 为 i 的子节点,则有:h[j]==h[i]+1;
k=lca(a,b),则length=h[a]-h[k]+h[b]-h[k]。
2.如何求解最近公共祖先:
①tarjan。(这个这里不讲,请自行百度学习)
②转化为 rmq 问题求解。
若树的形状为下图所示的二叉树,我们对其进行中序遍历,则有:
中序遍历:4 2 5 1 6 3 7
深度: 3 2 3 1 3 2 3
观察发现:数的父节点的深度总是低于自己节点的;中序遍历序列,父节点总位于两子节序列之间。
==》中序遍历中,a、b的父节点即为[a,b]上,深度h最小的点。
于是,求解LCA的问题就转化为求解一段序列的最小值,也即 rmq 问题。
这里只罗列两种解决rmq问题的方法:
第一个:线段树,这个就不多讲了。
第二个:st算法。
我们用a[i][j]便是以 i 为起点,长度为 2^j 的序列上的最小值。
a[i][j]=min(a[i][j-1],a[i+2^(j-1)][j-1]);
查询区间[l,r]的最小值:
i= r - l ,j=(int)(log(j*1.0)/log(2.0)),k=2^j
区间最小值为:min(a[i][j],a[r-k+1][j])。
st与线段树相比,空间消耗是一样的,但是线段树的查询是o(log n),而st的查询是o(1)的。
那现在如果不是二叉树呢?
只要保证生成序列中,每两个子序列之间有一个父节点见他们隔开即可,比如:
遍历序列:4 2 5 1 8 1 6 3 7
我在程序里,为了写的方便,生成的序列是用两个父节点包围一个子序列,即:
1 (2 4 2 5 2) 1 8 1 (3 6 3 7 3) 1
③倍增。
这可谓是处理树形数据的常用方法之一。
up[i][j]表示节点沿着树上的路径,向上走2^j 步所能到达的点,则:
up[i][j]=up[ up[i][j-1] ] [ j-1 ]
其他的具体应用,就去看下方的代码吧。
rmq-st:
以 1 作为根节点建树,h[i]记录点 i 在书中的深度,i的子节点j的深度即为:h[j]=h[i]+1;
b[i]记录 i 在生成序列中的位置(我这里取最后一次出现的地方,其实随便哪里都可以)
a[i][j]记录序列中从 i 开始,长度为2^j的子序列中,h值最小的点,也就是这一子区间内的最近公共祖先
d[i]=2^i
head[]用来存储变边
<span style="font-size:18px;">
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
const int maxn=5e5;
int sum,h[maxn+10];
int a[maxn*2+10][21];
int b[maxn+10];
int d[30];
struct tnode{
int x;
tnode *next;
}*head[maxn+10];
void add_edge(int x,int y)
{
tnode *p=new tnode;
(*p).x=y,(*p).next=head[x];
head[x]=p;
}
void build_queue(int x)
{
tnode *p;int k;
a[++sum][0]=x,b[x]=sum;
for(p=head[x];p;p=(*p).next)
{
if(h[(k=(*p).x)])continue;
h[k]=h[x]+1,build_queue(k);
a[++sum][0]=x,b[x]=sum;
}
}
int get_father(int x,int y)
{
int i,j,k,s=b[x],t=b[y];
if(s>t)swap(s,t);
k=(int)(log(t-s+1.0)/log(2.0));
i=a[s][k],j=a[t-d[k]+1][k];
return (h[i]<h[j])?i:j;
}
int get_length(int k,int x,int y,int z)
{
int ans=h[x]+h[y]-2*h[k];
int i=get_father(k,z);
ans+=h[z]+h[k]-2*h[i];
return ans;
}
int main()
{
int n,m,i,j,k,x,y,z,xy,xz,yz,ans1,ans2;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)h[i]=NULL;
for(d[0]=1,i=1;i<=19;i++)d[i]=d[i-1]*2;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add_edge(x,y),add_edge(y,x);
}
h[1]=1,sum=0,build_queue(1);
for(j=1;j<=19;j++)
for(i=1;i+d[j]-1<=sum;i++)
if(h[a[i][j-1]]<h[a[i+d[j-1]][j-1]])
a[i][j]=a[i][j-1];
else a[i][j]=a[i+d[j-1]][j-1];
for(i=1;i<=m;i++)
{
scanf("%d%d%d",&x,&y,&z);
ans1=get_father(x,y);
ans2=get_length(ans1,x,y,z);
k=get_father(x,z),j=get_length(k,x,z,y);
if(j<ans2)ans1=k,ans2=j;
k=get_father(y,z),j=get_length(k,y,z,x);
if(j<ans2)ans1=k,ans2=j;
printf("%d %d\n",ans1,ans2);
}
return 0;
}</span>
倍增:
h[i]:记录点i在树中的深度,子节点j的深度为:h[j]=h[i]+1
up[i][j]:记录点i向上走2^i步所到达的点,up[i][j]=up[up[i][j-1]][j-1]
head[]:记录树中的边
<span style="font-size:18px;">#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=5e5;
int up[maxn+10][20];
int h[maxn+10];
struct tnode{
int x;
tnode *next;
}*head[maxn+10];
void add_edge(int x,int y)
{
tnode *p=new tnode;
(*p).x=y,(*p).next=head[x],head[x]=p;
}
void build_tree(int x)
{
tnode *p;int i,k;
for(p=head[x];p;p=(*p).next)
{
if(h[k=(*p).x])continue;
h[k]=h[x]+1,up[k][0]=x;
for(i=1;i<=18;i++)
if(!(up[k][i]=up[up[k][i-1]][i-1]))break;
build_tree(k);
}
}
int get_father(int x,int y)
{
if(h[x]<h[y])swap(x,y);
int i,len=h[x]-h[y];
for(i=0;i<=18;i++)
if((1<<i)&len)x=up[x][i];
for(i=18;i>=0;i--)
if(up[x][i]!=up[y][i])
x=up[x][i],y=up[y][i];
if(x==y)return x;
return up[x][0];
}
int get_length(int k,int x,int y,int z)
{
int ans=h[x]+h[y]-2*h[k];
int i=get_father(k,z);
return ans+=h[k]+h[z]-2*h[i];
}
int main()
{
int n,m,i,j,k,x,y,z,ans1,ans2;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)head[i]=NULL,h[i]=0;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add_edge(x,y),add_edge(y,x);
}
h[1]=1,up[1][0]=0,build_tree(1);
for(k=1;k<=m;k++)
{
scanf("%d%d%d",&x,&y,&z);
ans1=get_father(x,y),ans2=get_length(ans1,x,y,z);
i=get_father(x,z),j=get_length(i,x,z,y);
if(j<ans2)ans1=i,ans2=j;
i=get_father(y,z),j=get_length(i,y,z,x);
if(j<ans2)ans1=i,ans2=j;
printf("%d %d\n",ans1,ans2);
}
return 0;
}</span>