Tarjan_lca
首先,以一道题引入:(出自洛谷P3379)
【题目描述】
给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。
【输入输出格式】
输入格式:
第一行包含三个正整数N、M、S,分别表示树的结点个数、询问的个数和树根结点的序号。
接下来N-1行每行包含两个正整数x、y,表示x结点和y结点只见有一条直接连接的边(数据保证可以构成树)。
接下来M行每行包含两个正整数a、b,表示询问a结点和b结点的最近公共祖先。
输出格式:
输出包含M行,每行包含一个正整数,依次为每一个询问的结果。
【输入输出样例】
输入样例1:
5 5 4
3 1
2 4
5 1
1 4
2 4
3 2
3 5
1 2
4 5
输出样例1:
4
4
1
4
4
【说明】
时空限制1000ms,128M
数据规模:
- 对于30%的数据:N<=10,M<=10
- 对于70%的数据:N<=10000,M<=10000
- 对于100%的数据:N<=500000,M<=500000
样例说明:
该树结构如下:
第一次询问:2、4的最近公共祖先,故为4。
第二次询问:3、2的最近公共祖先,故为4。
第三次询问:3、5的最近公共祖先,故为1。
第四次询问:1、2的最近公共祖先,故为4。
第五次询问:4、5的最近公共祖先,故为4。
故输出依次为4、4、1、4、4。
题目明明白白,求lca,大家很容易想到倍增lca,时间复杂度为O(nlogn+mlogn),这样不会超时,但如果题目复杂一点,运用次数多很容易会爆。那么我要用一种O(n+m)的方法去求lca。
对于以U为根的子树,树内任意一点V,lca(U,V)=U,很显然。
U有不同的子节点u1,u2,以u1为根的子树中有一节点v1,以u2为根的子树中有一节点v2,则lca(v1,v2)=U。
那么我们就可以如下解决:
我们利用并查集的思想。
将询问弄成一个前向星。
对于一个节点,在处理完它的所有子节点后,枚举它的询问,如果它的对象被查找过,则它们的lca为它的对象的并查集中的祖先。
解释:对于一棵子树,它的根节点U,U的询问对象为V,V被查找过,那么有两种可能:
1. U在以V为根的子树中,则并查集中Father(V)=V
2. U与V没有祖先后代关系,则有lca(U,V)=Father(V),因为在搜lca(U,V)的子节点时,先搜了V再搜U,这时Father(V)已经指向了lca(U,V)
最后将该节点在并查集中指向它的父亲节点。
事实上它是O(α(n)n+m)的
#include<cstdio>
#include<cstring>
using namespace std;
int to[1001001],nex[1001001],fir[500100],qto[1001001],qnex[1001001],qfir[500100],n,m,S,ans[500100],fa[500100],top;
bool visited[500100];
int find(int w)
{
if(fa[w]!=w)fa[w]=find(fa[w]);
return fa[w];
}
void Tarjan_lca(int w)
{
visited[w]=1;
for(int i=fir[w];i;i=nex[i])
{
if(visited[to[i]])continue;
Tarjan_lca(to[i]);
fa[to[i]]=w;
}
for(int i=qfir[w];i;i=qnex[i])
{
int x=qto[i];
if(visited[x])
{
ans[i>>1]=find(x);
}
}
}
int main()
{
scanf("%d%d%d",&n,&m,&S);
for(int i=1;i<=n;i++)fa[i]=i;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
to[++top]=y;nex[top]=fir[x];fir[x]=top;
to[++top]=x;nex[top]=fir[y];fir[y]=top;
}
top=1;
for(int i=1;i<=m;i++)
{
int x,y;
scanf("%d%d",&x,&y);
qto[++top]=y;qnex[top]=qfir[x];qfir[x]=top;
qto[++top]=x;qnex[top]=qfir[y];qfir[y]=top;
}
Tarjan_lca(S);
for(int i=1;i<=m;i++)printf("%d\n",ans[i]);
}
可能在对于比较深层的树中需要人工栈。
#include<cstdio>
#include<cstring>
using namespace std;
int to[1001001],nex[1001001],fir[500100],qto[1001001],qnex[1001001],qfir[500100],n,m,S,ans[500100],fa[500100],top,c[500100],stack[500100],las[500100];
bool visited[500100];
int find(int w)
{
if(fa[w]!=w)fa[w]=find(fa[w]);
return fa[w];
}
void Tarjan_lca(int w)
{
visited[w]=1;
stack[++top]=w;
c[top]=fir[w];
las[top]=0;
while(top)
{
bool p=0;
w=stack[top];
fa[las[top]]=w;
for(int i=c[top];i;i=nex[i])
{
if(visited[to[i]])continue;
las[top]=to[i];
stack[++top]=to[i];
c[top]=fir[to[i]];
las[top]=0;
p=1;
visited[to[i]]=1;
break;
}
if(p)continue;
for(int i=qfir[w];i;i=qnex[i])
{
int x=qto[i];
if(visited[x])
{
ans[i>>1]=find(x);
}
}
top--;
}
}
int main()
{
scanf("%d%d%d",&n,&m,&S);
for(int i=1;i<=n;i++)fa[i]=i;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
to[++top]=y;nex[top]=fir[x];fir[x]=top;
to[++top]=x;nex[top]=fir[y];fir[y]=top;
}
top=1;
for(int i=1;i<=m;i++)
{
int x,y;
scanf("%d%d",&x,&y);
qto[++top]=y;qnex[top]=qfir[x];qfir[x]=top;
qto[++top]=x;qnex[top]=qfir[y];qfir[y]=top;
}
top=0;
Tarjan_lca(S);
for(int i=1;i<=m;i++)printf("%d\n",ans[i]);
}
再例如NOIP2015提高的最后一题,运输计划,多少个人95分:
#include<cstdio>
#include<cstring>
using namespace std;
const int N=300300;
int n,m,q[N][2],to[2*N],next[2*N],fir[N],las[N],f[21][N],poi[N],tot=0,floor[N],bz[N],data[N],loc[N],top=0;
long long ans,len[N],v[2*N],p[N],dis[21][N],leng[N],max=0,d[N];
void swap(int &x,int &y){int z=x;x=y;y=z;}
void swap(long long &x,long long &y){long long z=x;x=y;y=z;}
void put(int,int,long long);
void build(int x);
long long get(int,int);
void find(int,int);
void qs(int,int);
int Get(int);
int main()
{
ans=100000000000000;
freopen("transport.in","r",stdin);
freopen("transport.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y;
long long z;
scanf("%d%d%lld",&x,&y,&z);
put(x,y,z),put(y,x,z);
}
memset(bz,0,sizeof(bz));
memset(f,0,sizeof(f));
memset(dis,0,sizeof(dis));
memset(floor,0,sizeof(floor));
floor[1]=1;
build(1);
for(int j=1;j<21;j++)
{
for(int i=1;i<=n;i++)
{
f[j][i]=f[j-1][f[j-1][i]];
dis[j][i]=dis[j-1][i]+dis[j-1][f[j-1][i]];
}
}
int x,y;
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i][0],&q[i][1]);
leng[i]=get(q[i][0],q[i][1]);
}
qs(1,m);
memset(bz,0,sizeof(bz));
find(q[1][0],q[1][1]);
memset(bz,0,sizeof(bz));
for(int i=1;i<=top;i++)
{
loc[poi[i]]=i;
data[i]=poi[i];
bz[poi[i]]=poi[i];
}
x=0,y=top;
while(x<y)
{
int now=data[++x];
for(int i=fir[now];i;i=next[i])
{
int o=to[i];
if(!bz[o])
{
bz[o]=bz[now];
data[++y]=o;
}
}
}
x=1;y=top;
for(int i=2;i<=m;i++)
{
int xx=Get(q[i][0]),yy=Get(q[i][1]);
if(xx>yy)swap(xx,yy);
if(xx>=y)
{
len[x]+=leng[i];
len[y]-=leng[i];
break;
}
if(yy<=x)
{
len[x]+=leng[i];
len[y]-=leng[i];
break;
}
if(xx>x)
{
len[xx]-=leng[i];
len[x]+=leng[i];
x=xx;
}
if(yy<y)
{
len[yy]+=leng[i];
len[y]-=leng[i];
y=yy;
}
}
for(int i=1;i<top;i++)
{
len[i]+=len[i-1];
long long temp=leng[1]-d[i];
if(temp<len[i])temp=len[i];
if(temp<ans)ans=temp;
}
if(top-1)printf("%lld",ans);else printf("0");
}
int Get(int x)
{
return loc[bz[x]];
}
void qs(int h,int t)
{
int l=h,r=t;
long long n=leng[(h+t)>>1];
do
{
while(leng[l]>n)l++;
while(leng[r]<n)r--;
if(l<=r)
{
swap(q[l][0],q[r][0]);
swap(q[l][1],q[r][1]);
swap(leng[l],leng[r]);
l++;r--;
}
}while(l<=r);
if(h<r)qs(h,r);if(l<t)qs(l,t);
}
void find(int x,int t)
{
bz[x]=1;
poi[++top]=x;
if(t==x)return;
for(int i=fir[x];i;i=next[i])
{
int now=to[i];
if(bz[now])continue;
d[top]=v[i];
find(now,t);
if(t==poi[top])return;
}
top--;
}
long long get(int x,int y)
{
if(floor[x]>floor[y])swap(x,y);
long long ans=0;
for(int i=20;i+1;i--)
{
if(floor[f[i][y]]>=floor[x])
{
ans+=dis[i][y];
y=f[i][y];
if(floor[x]==floor[y])break;
}
}
if(x==y)return ans;
for(int i=20;i+1;i--)
{
if(f[i][x]!=f[i][y])
{
ans+=dis[i][x]+dis[i][y];
x=f[i][x];y=f[i][y];
}
}
ans+=dis[0][x]+dis[0][y];
return ans;
}
void build(int x)
{
bz[x]=1;
for(int i=fir[x];i;i=next[i])
{
int n=to[i];
if(!bz[n])
{
f[0][n]=x;
dis[0][n]=v[i];
floor[n]=floor[x]+1;
build(n);
}
}
}
void put(int x,int y,long long z)
{
to[++tot]=y;next[tot]=0;v[tot]=z;
if(fir[x])next[las[x]]=tot;else fir[x]=tot;
las[x]=tot;
}
这是用倍增的结果,死命改也不对。
然后换了Tarjan lca:
#include<cstdio>
#include<cstring>
using namespace std;
const int N=300300;
int n,m,q[N][3],to[2*N],next[2*N],fir[N],f[21][N],poi[N],tot=0,floor[N],bz[N],data[N];
int loc[N],top=0,up[N][2],qto[2*N],qfir[N],fa[N],qnex[2*N],Top=0,st[N],c[N],qown[N*2],C[N];
long long ans,len[N],v[2*N],p[N],dis[21][N],leng[N],max=0,d[N],dist[N];
void swap(int &x,int &y){int z=x;x=y;y=z;}
void swap(long long &x,long long &y){long long z=x;x=y;y=z;}
void build(int x);
void qs(int,int);
int Find(int w)
{
if(fa[w]!=w)fa[w]=Find(fa[w]);
return fa[w];
}
void Tarjan_lca(int w)
{
c[1]=fir[w];
C[1]=0;
st[1]=w;
Top=1;
while(Top)
{
bool p=0;
w=st[Top];
bz[w]=1;
fa[to[C[Top]]]=w;
for(int i=c[Top];i;i=next[i])
{
if(!bz[to[i]])
{
C[Top]=i;
c[Top]=next[i];
st[++Top]=to[i];
c[Top]=fir[to[i]];
C[Top]=0;
p=1;break;
}
}
if(p)continue;
for(int i=qfir[w];i;i=qnex[i])
{
int x=qto[i],s=qown[i];
if(bz[x])
{
if(s*2==i)
{
q[s][2]=Find(x);
leng[s]=dist[w]+dist[x]-2*dist[q[s][2]];
up[s][1]=floor[w]-floor[q[s][2]];
up[s][0]=floor[x]-floor[q[s][2]];
}else
{
q[s][2]=Find(x);
leng[s]=dist[w]+dist[x]-2*dist[q[s][2]];
up[s][0]=floor[w]-floor[q[s][2]];
up[s][1]=floor[x]-floor[q[s][2]];
}
}
}
Top--;
}
}
int Read()
{
int f=1,p=0;char c=getchar();
while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
while(c>='0' && c<='9'){p=p*10+c-'0';c=getchar();}
return f*p;
}
long long read()
{
long long f=1,p=0;char c=getchar();
while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
while(c>='0' && c<='9'){p=p*10+c-'0';c=getchar();}
return f*p;
}
int main()
{
ans=100000000000000;
freopen("transport.in","r",stdin);
freopen("transport.out","w",stdout);
scanf("%d%d",&n,&m);
memset(fir,0,sizeof(fir));
memset(qfir,0,sizeof(qfir));
memset(next,0,sizeof(next));
memset(qnex,0,sizeof(qnex));
for(int i=1;i<n;i++)
{
int x,y;
long long z;
x=Read();
y=Read();
z=read();
to[++tot]=y;next[tot]=fir[x];v[tot]=z;
fir[x]=tot;
to[++tot]=x;next[tot]=fir[y];v[tot]=z;
fir[y]=tot;
}
for(int i=1;i<=n;i++)fa[i]=i;
memset(bz,0,sizeof(bz));
memset(f,0,sizeof(f));
memset(dis,0,sizeof(dis));
memset(floor,0,sizeof(floor));
floor[1]=1;
build(1);
for(int j=1;j<21;j++)
{
for(int i=1;i<=n;i++)
{
f[j][i]=f[j-1][f[j-1][i]];
dis[j][i]=dis[j-1][i]+dis[j-1][f[j-1][i]];
}
}
int x,y;
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i][0],&q[i][1]);
qto[++Top]=q[i][1];
qnex[Top]=qfir[q[i][0]];
qfir[q[i][0]]=Top;
qown[Top]=i;
qto[++Top]=q[i][0];
qnex[Top]=qfir[q[i][1]];
qfir[q[i][1]]=Top;
qown[Top]=i;
}
memset(bz,0,sizeof(bz));
Tarjan_lca(1);
qs(1,m);
int point=q[1][0];
top=1;
poi[1]=q[1][0];
d[1]=dis[0][point];
for(int i=1;i<=up[1][0];i++)
{
point=f[0][point];
d[++top]=dis[0][point];
poi[top]=point;
}
poi[++top]=q[1][1];
int wz=top;
point=q[1][1];
d[top]=dis[0][point];
for(int i=1;i<up[1][1];i++)
{
point=f[0][point];
d[++top]=dis[0][point];
poi[top]=point;
}
for(int i=wz;i<=(wz+top)/2;i++)swap(poi[i],poi[top-i+wz]),swap(d[i],d[top-i+wz]);
for(int i=wz-1;i<top;i++)d[i]=d[i+1];
memset(bz,0,sizeof(bz));
for(int i=1;i<=top;i++)
{
loc[poi[i]]=i;
data[i]=poi[i];
bz[poi[i]]=poi[i];
}
x=0,y=top;
while(x<y)
{
int now=data[++x];
for(int i=fir[now];i;i=next[i])
{
int o=to[i];
if(!bz[o])
{
bz[o]=bz[now];
data[++y]=o;
}
}
}
x=1;y=top;
for(int i=2;i<=m;i++)
{
int xx=loc[bz[q[i][0]]],yy=loc[bz[q[i][1]]];
if(xx>yy)swap(xx,yy);
if(xx>=y)
{
len[x]+=leng[i];
len[y]-=leng[i];
break;
}
if(yy<=x)
{
len[x]+=leng[i];
len[y]-=leng[i];
break;
}
if(xx>x)
{
len[xx]-=leng[i];
len[x]+=leng[i];
x=xx;
}
if(yy<y)
{
len[yy]+=leng[i];
len[y]-=leng[i];
y=yy;
}
}
for(int i=1;i<top;i++)
{
len[i]+=len[i-1];
long long temp=leng[1]-d[i];
if(temp<len[i])temp=len[i];
if(temp<ans)ans=temp;
}
if(top-1)printf("%lld",ans);else printf("0");
}
void qs(int h,int t)
{
int l=h,r=t;
long long n=leng[(h+t)>>1];
do
{
while(leng[l]>n)l++;
while(leng[r]<n)r--;
if(l<=r)
{
int t;
long long t1;
t=q[l][0];q[l][0]=q[r][0];q[r][0]=t;
t=q[l][1];q[l][1]=q[r][1];q[r][1]=t;
t1=leng[l];leng[l]=leng[r];leng[r]=t1;
t=up[l][0];up[l][0]=up[r][0];up[r][0]=t;
t=up[l][1];up[l][1]=up[r][1];up[r][1]=t;
l++;r--;
}
}while(l<=r);
if(h<r)qs(h,r);if(l<t)qs(l,t);
}
void build(int x)
{
bz[x]=1;
data[1]=x;
int h=0,t=1;
while(h<t)
{
int now=data[++h];
for(int i=fir[now];i;i=next[i])
{
x=to[i];
if(!bz[x])
{
bz[x]=1;
dist[x]=dist[now]+v[i];
f[0][x]=now;
dis[0][x]=v[i];
floor[x]=floor[now]+1;
data[++t]=x;
}
}
}
}