题目大意: 给出一棵树和多个询问,每次询问给出 l , r , z l,r,z l,r,z,求 ∑ i = l r d e p ( l c a ( i , z ) ) \sum_{i=l}^r dep(lca(i,z)) ∑i=lrdep(lca(i,z))。
题解
显然每个询问可以差分成 1 1 1 ~ r r r 与 z z z 的 l c a lca lca 减去 1 1 1 ~ l − 1 l-1 l−1 的 l c a lca lca,那么可以将这新的 2 q 2q 2q 个询问从小到大排序。
发现 d e p ( i ) dep(i) dep(i) 等于 1 1 1 到 i i i 路径上的点数,那么对于 d e p ( l c a ( x , y ) ) dep(lca(x,y)) dep(lca(x,y)),如果将 1 1 1 ~ x x x 的路径上的点权值 + 1 +1 +1,则 1 1 1 ~ y y y 路径上的点权和就是 d e p ( l c a ( x , y ) ) dep(lca(x,y)) dep(lca(x,y))。
根据这个转化,将询问排序后,依次加入每个点,求值的时候 1 1 1 到 z z z 路径上的点权和即为所求,用树剖 + + + 树状数组维护一下即可。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 100010
#define mod 201314
int n,q;
struct edge{int y,next;}e[maxn];
int first[maxn],len=0;
void buildroad(int x,int y){e[++len]=(edge){y,first[x]};first[x]=len;}
int fa[maxn],size[maxn],mson[maxn];
void dfs1(int x)
{
size[x]=1;
for(int i=first[x];i;i=e[i].next)
{
int y=e[i].y;if(y==fa[x])continue;
fa[y]=x;dfs1(y);size[x]+=size[y];
if(size[y]>size[mson[x]])mson[x]=y;
}
}
int id[maxn],tot=0,top[maxn];
void dfs2(int x,int tp)
{
id[x]=++tot;top[x]=tp;
if(mson[x])dfs2(mson[x],tp);
for(int i=first[x];i;i=e[i].next)
if(e[i].y!=fa[x]&&e[i].y!=mson[x])dfs2(e[i].y,e[i].y);
}
int tr1[maxn],tr2[maxn];
void add(int &x,int y){if(x+y>=mod)x=x+y-mod;else if(x+y<0)x=x+y+mod;else x=x+y;}
void tr_add(int *tr,int x,int y){for(;x<=maxn-10;x+=(x&-x))add(tr[x],y);}
int tr_sum(int *tr,int x){int re=0;for(;x;x-=(x&-x))add(re,tr[x]);return re;}
void change(int x,int y){tr_add(tr1,x,1);tr_add(tr2,x,x-1);tr_add(tr1,y+1,-1);tr_add(tr2,y+1,-y);}
int sum(int x){return (1ll*tr_sum(tr1,x)*x%mod-tr_sum(tr2,x)+mod)%mod;}
int getsum(int x,int y){return (sum(y)-sum(x-1)+mod)%mod;}
struct que{int x,y,pos,type;}ask[maxn];
int t=0,ans[maxn];
bool cmp(que x,que y){return x.x<y.x;}
void go1(int x){while(x)change(id[top[x]],id[x]),x=fa[top[x]];}
int go2(int x){int re=0;while(x)add(re,getsum(id[top[x]],id[x])),x=fa[top[x]];return re;}
int main()
{
scanf("%d %d",&n,&q);for(int i=2,fa;i<=n;i++)
scanf("%d",&fa),buildroad(fa+1,i); dfs1(1); dfs2(1,1);
for(int i=1,l,r,z;i<=q;i++)scanf("%d %d %d",&l,&r,&z),
ask[++t]=(que){r+1,z+1,i,1},ask[++t]=(que){l,z+1,i,-1};
sort(ask+1,ask+t+1,cmp);int now=1;
for(int i=1;i<=t;i++){
while(now<=n&&now<=ask[i].x)go1(now),now++;
add(ans[ask[i].pos],ask[i].type*go2(ask[i].y));
}
for(int i=1;i<=q;i++)printf("%d\n",ans[i]);
}