一道树剖裸题(学校OJ上的):
树链剖分
题目描述
一棵树有N个结点,刚开始,每条边的权值都是0。有M个操作,每个操作是如下两种操作之一:
1、格式是:P A B,表示结点A到结点B的路径上的所有边的权值都增加1。
2、格式是:Q A B,表示询问结点A和结点B之间的那条边的权值是多少,结点A和结点B是相邻结点。
输入格式
第1行,N和M。2 <= N <= 100000, 1 <= M <= 100000。
第2..N行,每行两个整数u,v。表示结点u和结点v之间有一条边。
接下来有M行,每行是上文提及的一种操作。
输出格式
依次按照输入数据的询问次序,每个询问输出一个整数。
输入样例
4 6
1 4
2 4
3 4
P 2 3
P 1 3
Q 3 4
P 1 4
Q 2 4
Q 1 4
输出样例
2
1
2
本题中n=10^5可能会爆栈,于是我们来三次Bfs,第一次顺序求出fa,dep;第二次通过队列从后往前求size,son;第三次再按队列从前往后求top,w,就不会爆栈了,详见代码吧。
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=100100;
struct data
{
int obj,_Next;
} e[maxn<<1];
int head[maxn];
int cur=-1;
int tree[maxn<<2];
int que[maxn];
int he,ta;
int fa[maxn];
int dep[maxn];
int _Size[maxn];
int _Son[maxn];
int top[maxn];
int w[maxn];
int n,m;
void Add(int x,int y)
{
cur++;
e[cur].obj=y;
e[cur]._Next=head[x];
head[x]=cur;
}
void Bfs1()
{
fa[1]=0;
dep[1]=1;
he=0,ta=1;
que[1]=1;
while (he<ta)
{
he++;
int node=que[he];
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node])
{
fa[son]=node;
dep[son]=dep[node]+1;
ta++;
que[ta]=son;
}
p=e[p]._Next;
}
}
}
void Bfs2()
{
for (int i=1; i<=n; i++) _Size[i]=1;
he=1,ta=n;
while (he<ta)
{
int son=que[ta];
ta--;
int node=fa[son];
_Size[node]+=_Size[son];
if (_Size[son]>_Size[_Son[node]]) _Son[node]=son;
}
}
void Bfs3()
{
w[1]=0;
top[1]=1;
for (int i=1; i<n; i++)
{
int node=que[i];
int heavy_son=_Son[node];
if (!heavy_son) continue;
w[heavy_son]=w[node]+1;
top[heavy_son]=top[node];
int _Time=w[node]+_Size[heavy_son]+1;
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( son!=fa[node] && son!=heavy_son )
{
w[son]=_Time;
_Time+=_Size[son];
top[son]=son;
}
p=e[p]._Next;
}
}
}
void Down(int root)
{
int left=root<<1;
int right=left|1;
tree[left]+=tree[root];
tree[right]+=tree[root];
tree[root]=0;
}
void Update(int root,int L,int R,int x,int y)
{
if ( y<L || R<x ) return;
if ( x<=L && R<=y )
{
tree[root]++;
return;
}
Down(root);
int left=root<<1;
int right=left|1;
int mid=(L+R)>>1;
Update(left,L,mid,x,y);
Update(right,mid+1,R,x,y);
}
void Plus(int u,int v)
{
if (u==v) return;
if (top[u]==top[v])
{
if (w[u]>w[v]) swap(u,v);
Update(1,1,n,w[u]+1,w[v]);
return;
}
if (dep[ top[u] ]<dep[ top[v] ]) swap(u,v);
Update(1,1,n,w[ top[u] ],w[u]);
u=fa[ top[u] ];
Plus(u,v);
}
int Query(int root,int L,int R,int x)
{
if ( x<L || R<x ) return 0;
if ( L==x && x==R ) return tree[root];
Down(root);
int left=root<<1;
int right=left|1;
int mid=(L+R)>>1;
int vl=Query(left,L,mid,x);
int vr=Query(right,mid+1,R,x);
return vl+vr;
}
int main()
{
freopen("1578.in","r",stdin);
freopen("1578.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
Add(a,b);
Add(b,a);
}
Bfs1();
Bfs2();
Bfs3();
n--;
//for (int i=1; i<=n+1; i++) printf("%d ",w[i]);
//printf("\n");
for (int i=1; i<=m; i++)
{
char c=getchar();
while ( c!='P' && c!='Q' ) c=getchar();
int a,b;
scanf("%d%d",&a,&b);
if (c=='P') Plus(a,b);
else
{
int son=a;
if (a==fa[b]) son=b;
int ans=Query(1,1,n,w[son]);
printf("%d\n",ans);
}
//for (int j=1; j<=10; j++) printf("%d ",tree[j]);
//printf("\n");
}
return 0;
}