题目
给定一颗n(n≤150000)个点的树,每个点有点权,边有边权(表示两个点之间的距离)。q(q≤200000)次询问,每次询问点权在[L,R]之间的所有点到某个点的距离之和。强制在线。
思路
考虑建立点分树,分治结构每个点都储存对应分治范围(简称范围)内的信息,询问时从c向后跳分治链,并逐级将信息合并得到整棵树的答案。
对于链上某一点x,设前一个点为px,显然我们需要用到的,x范围中(除去px范围)的所有合法的点(即点权在[l,r]内的点)到c的距离之和,拆开为这些点到x的距离和+这些点的个数乘以x到c的距离。对于第一部分我们将它看作x范围内所有点到x的距离和-px范围内所有点到x的距离和,第二部分同理。
这样,我们只需要对每个节点x储存范围内所有点点权age、到x的距离dis以及到上一级分治中心fa[x]的距离ldis(放在集合d[x])。然后将d[x]按照age排序,那么x的分治范围对于查询中心c的产生贡的点献处于d[x]的一段区间上。利用前(后)缀和+二分就能地很好处理了。
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=150010,M=N;
int n,q,A;ll ans;
EE(1); int x[N];
int siz[N],S,root,mxr;bool vis[N];
void get_siz(int u,int fa)
{
siz[u]=1;
for(int i=head[u],v; i; i=e[i].nxt) if((v=e[i].to) != fa && !vis[v])
get_siz(v,u),siz[u] += siz[v];
}
void get_root(int u,int fa)
{
int maxx=S - siz[u];
for(int i=head[u],v; i; i=e[i].nxt) if((v=e[i].to) != fa && !vis[v])
get_root(v,u),maxx=max(maxx,siz[v]);
if(maxx < mxr) mxr=maxx,root=u;
}
struct anc { int to;ll dis; int ty; };vector<anc> v[N];
struct data
{
int w;ll num,sum;
friend bool operator < (data a,data b) { return a.w < b.w; }
};vector<data> s[N][3];
int dep[N];
void dfs(int u,int fa,int fr,int w)
{
v[u].push_back({fr,dep[u],w});
s[fr][w].push_back({x[u],1,dep[u]});
for(int i=head[u],v; i; i=e[i].nxt) if((v=e[i].to) != fa && !vis[v])
dep[v]=dep[u] + e[i].w,dfs(v,u,fr,w);
}
void solve(int u)
{
get_siz(u,0),S=siz[u],mxr=INF,get_root(u,0);
vis[root]=1,v[root].push_back({root,0,-1});
if(siz[u] == 1) return;
int cnt=0;
for(int i=head[root],v; i; i=e[i].nxt) if(!vis[v=e[i].to])
{
dep[v]=e[i].w,dfs(v,root,root,cnt);
auto& now=s[root][cnt];
now.push_back({INF,0,0});
sort(now.begin(),now.end());
for(int j=now.size() - 2;~j;j--)
now[j].num += now[j + 1].num,now[j].sum += now[j + 1].sum;
cnt++;
}for(int i=head[root],v; i; i=e[i].nxt) if(!vis[v=e[i].to]) solve(v);
}
ll query(int l,int r,int u)
{
ll ans=0;
for(int i=v[u].size() - 1;~i; i--)
{
int fa=v[u][i].to;
for(int j=0;j < 3;j++)
{
auto& now=s[fa][j];
if(j == v[u][i].ty || now.empty()) continue;
auto L=lower_bound(now.begin(),now.end(),(data){l,0,0});
auto R=upper_bound(now.begin(),now.end(),(data){r,0,0});
ans += v[u][i].dis * (L->num - R->num) + L->sum - R->sum;
}if(l <= x[fa] && r >= x[fa]) ans += v[u][i].dis;
}
return ans;
}
int main()
{
n=read(),q=read(),A=read();
for(int i=1; i <= n; i++) x[i]=read();
for(int i=1; i < n; i++) { int x=read(),y=read();add_edge(x,y,read()); }
solve(1);
for(int i=1; i <= q; i++)
{
ll u=read(),l=read(),r=read();
l=(l + ans) % A,r=(r + ans) % A; if(l > r) swap(l,r);
fprint(ans=query(l,r,u));
}
}