题意
有一棵树。
多组询问,每个询问一组 ( l , , r , z ) (l, ,r, z) (l,,r,z) , 求 ∑ i = l r d e p [ L c a ( i , x ) ] \sum_{i=l}^{r} dep[Lca(i, x)] ∑i=lrdep[Lca(i,x)] 。
思路
好题,做法多样。这里总结 3 种做法。
首先是我的大常数做法:点分治。这是一种离线做法,还需要卡常才能过。
点分治非常暴力,甚至不用转化思路,LCA 就是 LCA。把每个询问挂在 x x x 那个点上。对于每个重心,求连边跨越重心的点对的 LCA 对询问的贡献。
可以发现,找出树的重心之后,一些子树原本就是重心的子树,但是有一棵子树深度是比重心浅的。
分类讨论:对于深度比重心浅的那些子树之间的点对,LCA 肯定是重心。而比重心浅的那棵子树最多只可能有一棵,他其中的任意一个点 u u u 和其他的子树中的点 v v v 组成的点对的 LCA 肯定就是 u u u 和重心的 LCA。
如此一来只要一棵线段树,将所有节点的贡献写到线段树上(需要两棵)。要处理某棵子树中的询问之前,先把自己子树的贡献去掉,然后直接在线段树上询问就好了。
复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n) 。
然后是大众做法:树剖线段树。这是一种离线做法。
两个点的LCA可以看成将一个点到根的所有点都 +1 ,另一个点询问他到根的点权和。
把一个询问的 ( l , r ) (l,r) (l,r) 拆成 ( 1 , l − 1 ) (1,l-1) (1,l−1) 和 ( 1 , r ) (1, r) (1,r) 。按点的编号每次加入一个点,将他到根的所有点 +1,并处理以这个点为右端点的询问。
只需要树剖 + 线段树就可以做到 O ( n log 2 n ) O(n\log^2n) O(nlog2n) 的时间复杂度。
p.s. 我并不知道怎么把这种做法用主席树变成在线。
再来一种神仙做法:分块。这是一种在线做法。
思路的转化和上个做法一样。
对于点的编号(并不是 dfs 序)分块。对于一个块内的点,建树,差分,可以 O ( n n ) O(n\sqrt n) O(nn) 预处理, O ( 1 ) O(1) O(1) 查询一个点到根的点权和。对于剩下的边角料,直接暴力求 LCA 贡献答案,使用倍增 RMQ 可以做到 O ( n log n ) O(n\log n) O(nlogn) 预处理, O ( 1 ) O(1) O(1) 查询。
所以复杂度是 O ( n n ) O(n\sqrt n) O(nn) 。
代码
点分治写法
#include<bits/stdc++.h>
using namespace std;
const int mod = 201314, inf = 1e9 + 7;
const int N = 5e4 + 10, M = N<<1, E = 16;
namespace Graph
{
int h[N], ecnt, nxt[M], v[M];
inline void clear(){ecnt = 1;}
inline void add_dir(int _u, int _v){
v[++ecnt] = _v;
nxt[ecnt] = h[_u]; h[_u] = ecnt;
}
inline void add_undir(int _u, int _v){
add_dir(_u, _v);
add_dir(_v, _u);
}
}
using namespace Graph;
inline void add(int &x, int y){x += y; if (x >= mod) x -= mod;} // 卡常
inline int _add(int x, int y){x += y; if (x >= mod) x -= mod; return x;}
struct Sum_tr{
#define ls (u<<1)
#define rs (u<<1^1)
int t[N<<2];
inline void clear(){
memset(t, 0, sizeof(t));
}
inline void push_up(int u){
t[u] = _add(t[ls], t[rs]);
}
void modify(int u, int l, int r, int P, int X){
if (l == r){add(t[u], X); return;}
int mid = l + r >> 1;
if (P <= mid) modify(ls, l, mid, P, X);
else modify(rs, mid+1, r, P, X);
push_up(u);
}
int query(int u, int l, int r, int L, int R){
if (L <= l && r <= R) return t[u];
int mid = l + r >> 1, ret = 0;
if (L <= mid) add(ret, query(ls, l, mid, L, R));
if (mid < R) add(ret, query(rs, mid+1, r, L, R));
return ret;
}
#undef ls
#undef rs
}s, t;
int n, q;
vector<int> le[N], ri[N], id[N];
int f[N][E], dep[N], ans[N];
int zx, mn, sz, siz[N], val[N];
bool vis[N];
template<class T>inline void read(T &x){
x = 0; bool fl = 0; char c = getchar();
while (!isdigit(c)){if (c == '-') fl = 1; c = getchar();}
while (isdigit(c)){x = (x<<3)+(x<<1)+c-'0'; c = getchar();}
if (fl) x = -x;
}
void dfs(int u, int fa)
{
dep[u] = dep[fa]+1;
f[u][0] = fa;
for (int i = 1; i < E; ++ i)
f[u][i] = f[f[u][i-1]][i-1];
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa)
dfs(v[i], u);
}
int lca(int x, int y)
{
if (dep[x] < dep[y]) swap(x, y);
for (int i = E-1; i >= 0; -- i)
if (dep[f[x][i]] >= dep[y])
x = f[x][i];
if (x == y) return x;
for (int i = E-1; i >= 0; -- i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
void find_zx(int u, int fa)
{
int son = 0;
siz[u] = 1;
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa && !vis[v[i]]){
find_zx(v[i], u);
son = max(son, siz[v[i]]);
siz[u] += siz[v[i]];
}
son = max(son, sz-siz[u]);
if (son < mn) zx = u, mn = son;
}
void init(int u, int fa)
{
val[u] = dep[lca(u, zx)];
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa && !vis[v[i]])
init(v[i], u);
}
void modify(int u, int fa, int opt, bool fl)
{
if (fl || val[u] == dep[zx]) s.modify(1, 1, n, u, opt*val[u]); // 卡常
if (fl || val[u] != dep[zx]) t.modify(1, 1, n, u, opt);
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa && !vis[v[i]])
modify(v[i], u, opt, fl);
}
void calc(int u, int fa)
{
for (int i = 0, ii = le[u].size(); i < ii; ++ i){
if (val[u] == dep[zx])
add(ans[id[u][i]], s.query(1, 1, n, le[u][i], ri[u][i]));
else add(ans[id[u][i]], 1LL*t.query(1, 1, n, le[u][i], ri[u][i])*val[u]%mod);
}
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa && !vis[v[i]])
calc(v[i], u);
}
void solve(int u)
{
vis[u] = 1;
for (int i = h[u]; i; i = nxt[i])
if (!vis[v[i]])
init(v[i], u), modify(v[i], u, 1, 1);
s.modify(1, 1, n, u, dep[u]);
t.modify(1, 1, n, u, 1);
for (int i = 0, ii = le[u].size(); i < ii; ++ i)
add(ans[id[u][i]], s.query(1, 1, n, le[u][i], ri[u][i]));
for (int i = h[u]; i; i = nxt[i])
if (!vis[v[i]]){
modify(v[i], u, -1, 0);
calc(v[i], u);
modify(v[i], u, 1, 0);
}
for (int i = h[u]; i; i = nxt[i])
if (!vis[v[i]])
modify(v[i], u, -1, 1);
s.modify(1, 1, n, u, -dep[u]);
t.modify(1, 1, n, u, -1);
for (int i = h[u]; i; i = nxt[i])
if (!vis[v[i]]){
sz = siz[v[i]]; mn = inf;
find_zx(v[i], u);
solve(zx);
}
}
int main()
{
read(n); read(q);
clear();
for (int i = 2; i <= n; ++ i){
int x; read(x);
add_undir(x+1, i);
}
for (int i = 1; i <= q; ++ i){
int l, r, x;
read(l); read(r); read(x);
l++; r++; x++;
le[x].push_back(l);
ri[x].push_back(r);
id[x].push_back(i);
}
dep[0] = 0;
dfs(1, 0);
memset(vis, 0, sizeof(vis));
s.clear();
sz = n; mn = inf;
find_zx(1, 0);
solve(zx);
for (int i = 1; i <= q; ++ i)
printf("%d\n", ans[i]);
return 0;
}