Description
给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)
Input
第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。
Output
输出q行,每行表示一个询问的答案。每个答案对201314取模输出
这道题不是一般的妙啊。
开始的想法是把z到根的路径打上标记,然后枚举l到r的点,向上跳到第一个有标记的点,答案加上这个点的深度。
关键是:这个点的深度就是它到根路径上的点的数量!
所以可以把z到根路径上每个点+1,然后询问l到r中每个点到根的路径上的权值和。
也可以反过来,把l到r中每个点到根上+1,询问根到z的路径和。
显然可以离线,记录一个类似于前缀和的东西就好了。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define mod 201314
using namespace std;
struct edge{
int to,next;
}ed[100010];
struct node{
int l,r,sl,sr,ind,x;
}a[50010];
int head[50010],size=0,cnt=0,ind[50010],top[50010],deep[50010],siz[50010],son[50010],fa[50010];
int sum[200010],add[200010];
void addd(int from,int to)
{
size++;
ed[size].to=to;
ed[size].next=head[from];
head[from]=size;
}
bool cmp1(node a,node b)
{
return a.l<b.l;
}
bool cmp2(node a,node b)
{
return a.r<b.r;
}
bool cmp3(node a,node b)
{
return a.ind<b.ind;
}
void dfs1(int u)
{
siz[u]=1;
int Max=-1;
for(int i=head[u];i;i=ed[i].next)
{
int v=ed[i].to;
if(v==fa[u]) continue;
deep[v]=deep[u]+1;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>Max)
{
Max=siz[v];
son[u]=v;
}
}
}
void dfs2(int u,int tp)
{
ind[u]=++cnt;
top[u]=tp;
if(!son[u]) return;
dfs2(son[u],tp);
for(int i=head[u];i;i=ed[i].next)
{
int v=ed[i].to;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
void work()
{
fa[0]=-1;
dfs1(0);
dfs2(0,0);
}
void push_up(int root)
{
sum[root]=sum[root<<1]+sum[root<<1|1];
}
void push_down(int root,int nl,int nr)
{
sum[root<<1]+=nl*add[root];
sum[root<<1|1]+=nr*add[root];
add[root<<1]+=add[root];
add[root<<1|1]+=add[root];
add[root]=0;
}
void update(int root,int l,int r,int x,int y,int w)
{
if(x<=l&&y>=r)
{
sum[root]+=w*(r-l+1);
add[root]+=w;
return;
}
int mid=(l+r)>>1;
push_down(root,mid-l+1,r-mid);
if(x<=mid) update(root<<1,l,mid,x,y,w);
if(y>mid) update(root<<1|1,mid+1,r,x,y,w);
push_up(root);
}
int query(int root,int l,int r,int x,int y)
{
if(x<=l&&y>=r)
{
return sum[root];
}
int mid=(l+r)>>1;
push_down(root,mid-l+1,r-mid);
int ans=0;
if(x<=mid) ans+=query(root<<1,l,mid,x,y);
if(y>mid) ans+=query(root<<1|1,mid+1,r,x,y);
push_up(root);
return ans;
}
void Plus(int x)
{
while(top[x]!=0)
{
update(1,1,cnt,ind[top[x]],ind[x],1);
x=fa[top[x]];
}
update(1,1,cnt,ind[0],ind[x],1);
}
int get(int x)
{
int ans=0;
while(top[x]!=0)
{
ans+=query(1,1,cnt,ind[top[x]],ind[x]);
x=fa[top[x]];
ans%=mod;
}
ans+=query(1,1,cnt,ind[0],ind[x]);
return ans%mod;
}
void clear()
{
memset(add,0,sizeof(add));
memset(sum,0,sizeof(sum));
}
int main()
{
int n,q;
scanf("%d%d",&n,&q);
for(int i=1;i<n;i++)
{
int x;
scanf("%d",&x);
addd(i,x);
addd(x,i);
}
work();
for(int i=1;i<=q;i++)
{
scanf("%d%d%d",&a[i].l,&a[i].r,&a[i].x);
a[i].ind=i;
}
sort(a+1,a+q+1,cmp1);
int y=-1;
for(int i=1;i<=q;i++)
{
while(y<a[i].l-1) Plus(++y);
a[i].sl=get(a[i].x);
}
sort(a+1,a+q+1,cmp2);
y=-1;
clear();
for(int i=1;i<=q;i++)
{
while(y<a[i].r) Plus(++y);
a[i].sr=get(a[i].x);
}
sort(a+1,a+q+1,cmp3);
for(int i=1;i<=q;i++) printf("%d\n",(a[i].sr-a[i].sl+3*mod)%mod);
return 0;
}