今天学了一下tarjan求lca(离线的),时间复杂度为O(n*a(n)),就是并查集的时间复杂度。
对于一个询问(u,v),我们先把它加进u开头的与v开头的邻接链表。然后做一遍Dfs。我们肯定会先Dfs到其中一个节点(假设是u),在Dfs到另外一个(v)。那么我们在Dfs到u的时候,把vis[u]标为true,然后在Dfs到v的时候,我们处理所有vis为true的v的邻接链表中的u。它们的lca便为当前的fa[u](这个很容易证明)。在做一个节点的时候,先Dfs它的子树,把它子节点的fa标记为它,然后把它自己标记为true,最后处理它的邻接链表。
我这个写得很简略(给自己复习用的)。这里有个blog写得不错,真心推荐一下(我当初就是看这个会的):
http://www.mamicode.com/info-detail-1067269.html
当然啦,学了之后把它套在了running(Noip2016 Day2 T2)的树剖做法上面,把原先那个垃圾的倍增求lca踢掉了。然而还是超一个点……(还是用树上差分好)
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<set>
using namespace std;
const int maxn=300100;
const int maxl=19;
struct data
{
int obj,_Next;
} e[maxn<<1];
int head[maxn];
int cur=-1;
struct data1
{
int id,val,num,_Next1;
} e1[maxn*maxl<<2];
int head1[maxn];
int cur1=-1;
int fa[maxn][maxl];
int dep[maxn];
int _Size[maxn];
int _Son[maxn];
int w[maxn];
int _Time;
int dfsx[maxn];
int top[maxn];
int que[maxn];
int he=0,ta=1;
int A[maxn];
int B[maxn];
int cntA[maxn];
int cntB[maxn<<1];
struct data2
{
int obj2,id2,_Next2;
} e2[maxn<<1];
int head2[maxn];
int cur2=-1;
int father[maxn];
bool vis[maxn];
struct data3
{
int nu,nv,nlca;
} ask[maxn];
int ans[maxn];
int n,m;
void Add(int x,int y)
{
cur++;
e[cur].obj=y;
e[cur]._Next=head[x];
head[x]=cur;
}
void Bfs1()
{
que[1]=1;
fa[1][0]=1;
dep[1]=1;
while (he<ta)
{
he++;
int node=que[he];
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node][0])
{
fa[son][0]=node;
dep[son]=dep[node]+1;
ta++;
que[ta]=son;
}
p=e[p]._Next;
}
}
}
void Bfs2()
{
for (int i=1; i<=n; i++) _Size[i]=1;
for (int i=n; i>=2; i--)
{
int son=que[i];
int node=fa[son][0];
_Size[node]+=_Size[son];
if (_Size[son]>_Size[ _Son[node] ])
_Son[node]=son;
}
}
void Bfs3()
{
top[1]=1;
w[1]=1;
dfsx[1]=1;
for (int i=1; i<n; i++)
{
int node=que[i];
int heavy_son=_Son[node];
_Time=w[node]+1;
if (heavy_son!=0)
{
top[heavy_son]=top[node];
w[heavy_son]=w[node]+1;
dfsx[ w[heavy_son] ]=heavy_son;
_Time+=_Size[heavy_son];
}
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( son!=heavy_son && son!=fa[node][0] )
{
top[son]=son;
w[son]=_Time;
dfsx[ w[son] ]=son;
_Time+=_Size[son];
}
p=e[p]._Next;
}
}
}
void Make_fa()
{
for (int j=1; j<maxl; j++)
for (int i=1; i<=n; i++)
fa[i][j]=fa[ fa[i][j-1] ][j-1];
}
void Add2(int x,int y,int nid)
{
cur2++;
e2[cur2].obj2=y;
e2[cur2].id2=nid;
e2[cur2]._Next2=head2[x];
head2[x]=cur2;
}
int Find_fa(int x)
{
if (father[x]==x) return x;
return father[x]=Find_fa(father[x]);
}
void Dfs(int node)
{
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node][0])
{
Dfs(son);
father[son]=node;
}
p=e[p]._Next;
}
vis[node]=true;
p=head2[node];
while (p!=-1)
{
int v=e2[p].obj2;
if (vis[v]) ask[ e2[p].id2 ].nlca=Find_fa(v);
p=e2[p]._Next2;
}
}
void Add1(int x,int nid,int nval,int nnum)
{
cur1++;
e1[cur1].id=nid;
e1[cur1].val=nval;
e1[cur1].num=nnum;
e1[cur1]._Next1=head1[x];
head1[x]=cur1;
}
void Plus(int u,int v,int nid,int nval)
{
if (top[u]==top[v])
{
if (w[u]>w[v]) swap(u,v);
Add1(w[u],nid,nval,1);
Add1(w[v]+1,nid,nval,-1);
return;
}
if (dep[ top[u] ]<dep[ top[v] ]) swap(u,v);
int tu=top[u];
Add1(w[tu],nid,nval,1);
Add1(w[u]+1,nid,nval,-1);
Plus(fa[tu][0],v,nid,nval);
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
Add(a,b);
Add(b,a);
}
Bfs1();
Bfs2();
Bfs3();
Make_fa();
for (int i=1; i<=n; i++)
{
int W;
scanf("%d",&W);
A[i]=dep[i]+W;
B[i]=dep[i]-W;
}
for (int i=1; i<=n; i++) head2[i]=-1;
for (int i=1; i<=m; i++)
{
int u,v;
scanf("%d%d",&u,&v);
Add2(u,v,i);
Add2(v,u,i);
ask[i].nu=u;
ask[i].nv=v;
}
for (int i=1; i<=n; i++) father[i]=i;
Dfs(1);
for (int i=1; i<=n+1; i++) head1[i]=-1;
for (int i=1; i<=m; i++)
{
int u=ask[i].nu;
int v=ask[i].nv;
int lca=ask[i].nlca;
int len=dep[u]+dep[v]-(dep[lca]<<1);
Plus(u,lca,0,dep[u]);
Plus(v,lca,1,dep[v]-len);
if (A[lca]==dep[u]) ans[lca]--;
}
for (int i=1; i<=n; i++)
{
int p=head1[i];
while (p!=-1)
{
int nid=e1[p].id;
int nval=e1[p].val;
int nnum=e1[p].num;
if (nid==0) cntA[nval]+=nnum;
else cntB[nval+maxn]+=nnum;
p=e1[p]._Next1;
}
int node=dfsx[i];
ans[node]+=cntA[ A[node] ];
ans[node]+=cntB[ B[node]+maxn ];
}
for (int i=1; i<=n; i++) printf("%d ",ans[i]);
printf("\n");
return 0;
}