题意:给定一棵树,边上有权,每个点有一个颜色A,q次询问,每次询问z,x,y表示颜色在[x,y]的所有点到点z的距离之和。
数据范围:满足 n<=150000,Q<=200000。对于所有数据,满足 A<=10^9
感觉这题非常难啊完全不会啊看上去是个动态树分治可是想了半天啥都没想到。。。。。
考虑dis(u,v)=dep(u)+dep(v)-2*dep(lca(u,v))(dis(u,v)表示u,v的距离,dep(x)表示根节点到x的距离),假如没有颜色的限制,那么问题就变成所有点到x的距离,假如我们从上面的式子进行考虑,u是确定的,v(就是1到n所有的节点)。dep(u)可以通过dfs求出,dep(v)(其实就是所有点到根的距离)可以通过dfs之后累加,关键点就是求x和v(其中1<=v<=n)的lca到根的距离和。
具体做法就是我们进行树链剖分,用线段树维护,线段树叶子节点维护的是对应的点和它的父亲这条路径被经过的距离和,非叶子节点维护的是区间和,那么对于一个节点x(1<=x<=n),我们一路跳到根节点,期间把每个点都加上这个点和这个点父亲的距离,一直加到根,区间加可以用线段树来实现(每次对x到top(x)进行区间加),具体做法的话假设我们做到点x,它的树链剖分序是t,则sum[t]=dis(x,fa(x)),然后我们需要维护sum的前缀和,这样在线段树区间加的时候,若是当前区间被目标区间完全包含则对对应节点加上sum[q]-sum[p-1]假设[p,q]是当前递归到的节点。查询x点和其他所有点的lca到根的距离和的时候,我们还是一路跳,把树上每个点的权值加上去,该过程可以用线段树优化。
考虑为什么这样做是对的,对于一个点x,它对它和它的祖先(设为y)的lca必然就是它的祖先本身,然后这样的贡献就是相当于对于每个y都加上dis(1,y),即每个y和它父亲这条路径要多经过一次。然后最后查询相当于是查一个点到根的路径上每条边被经过了的距离和。
因为颜色很大,所以离散化之后用可持久化线段树。
可持久化线段树区间加的下放标记要分2种类型,难写速度慢内存大。以下介绍一种简便方法:
设当前区间[p,q],目标区间[l,r],则我们在除了完全包含的区间外的所有区间都加上sum(min(r,q))-sum(max(l,p)-1)(这步相当于维护了一个区间和),否则我们另外开数组sum1对当前节点加1;查询的时候,每次走到一个被部分包含的节点就加上sum1*sum(min(r,q))-sum(max(l,p)-1),若走到完全包含的节点则再加上sum。证明请读者自行思考。
这傻逼被这题卡了一天。。
代码:
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN = 150005;
const int NOD = 10000005;
int first[MAXN], next[MAXN << 1], go[MAXN << 1], way[MAXN << 1], t;
int dis[MAXN], dep[MAXN], fa[MAXN], size[MAXN], top[MAXN], son[MAXN], pos[MAXN];
int n, m, i, j, k, x, y, z, A, b[MAXN], len, p, f[MAXN];
int sum1[NOD], lc[NOD], rc[NOD], root[MAXN];
long long ans, tot, num[MAXN], sum[NOD];
struct sb{
int point, color;
};
sb a[MAXN];
inline bool rule(const sb &a, const sb &b)
{
return a.color < b.color;
}
inline int get()
{
char c;
while ((c = getchar()) < 48 || c > 57);
int res = c - 48;
while ((c = getchar()) >= 48 && c <= 57)
res = res * 10 + c - 48;
return res;
}
inline void add(const int &x, const int &y, const int &z)
{
next[++t] = first[x]; first[x] = t; go[t] = y; way[t] = z;
next[++t] = first[y]; first[y] = t; go[t] = x; way[t] = z;
}
inline void dfs(int now)
{
size[now] = 1;
int son1 = 0, son2 = 0;
for(int i = first[now]; i; i = next[i])
if (fa[now] != go[i])
{
fa[go[i]] = now;
dep[go[i]] = dep[now] + way[i];
dis[go[i]] = way[i] + dis[now];
dfs(go[i]);
size[now] += size[go[i]];
if (size[go[i]] > son1) son1 = size[go[i]], son2 = go[i];
}
son[now] = son2;
}
inline void dfs1(int now)
{
pos[now] = ++t;
dis[t] = dep[now] - dep[fa[now]];
if (son[now])
{
top[son[now]] = top[now];
dfs1(son[now]);
}
for(int i = first[now]; i; i = next[i])
if (!pos[go[i]])
{
top[go[i]] = go[i];
dfs1(go[i]);
}
}
inline void insert(int &x, int y, int p, int q, int l, int r)
{
x = ++t;
lc[x] = lc[y];
rc[x] = rc[y];
sum[x] = sum[y];
sum1[x] = sum1[y];
if (p >= l && q <= r)
{
sum1[x] ++;
return;
}
sum[x] += dis[min(r, q)] - dis[max(l, p) - 1];
int mid = (p + q) >> 1;
if (mid >= l) insert(lc[x], lc[y], p, mid, l, r);
if (mid < r) insert(rc[x], rc[y], mid + 1, q, l, r);
}
inline void find(int k, int p, int q, int l, int r)
{
if (sum1[k]) tot += sum1[k] * (long long)(dis[min(r, q)] - dis[max(l, p) - 1]);
if (p >= l && q <= r)
{
tot += sum[k];
return;
}
int mid = (p + q) >> 1;
if (mid >= l) find(lc[k], p, mid, l, r);
if (mid < r) find(rc[k], mid + 1, q, l, r);
}
inline long long solve(int x, int y)
{
tot = 0;
while (x)
{
find(root[y], 1, n, pos[top[x]], pos[x]);
x = fa[top[x]];
}
return tot;
}
int main()
{
cin >> n >> m >> A;
for(i = 1; i <= n; i ++)
a[i].color = get(), a[i].point = i, b[++len] = a[i].color;
sort(a + 1, a + 1 + n, rule);
sort(b + 1, b + 1 + len);
len = unique(b + 1, b + 1 + len) - 1 - b;
for(i = 1; i < n; i ++)
{
x = get(); y = get(); z = get();
add(x, y, z);
}
dfs(1);
t = 0; top[1] = 1;
dfs1(1);
t = 0;
a[0].color = -1;
for(i = 2; i <= n; i ++)
dis[i] += dis[i - 1];
for(i = 1; i <= n; i ++)
{
if (a[i].color != a[i - 1].color) p ++, root[p] = root[p - 1];
f[p] ++;
num[p] += dep[a[i].point];
x = a[i].point;
while (x)
{
insert(root[p], root[p], 1, n, pos[top[x]], pos[x]);
x = fa[top[x]];
}
}
for(i = 2; i <= len; i ++)
num[i] += num[i - 1], f[i] += f[i - 1];
while (m --)
{
z = get(); x = get(); y = get();
x = (x + ans) % A;
y = (y + ans) % A;
if (x > y) swap(x, y);
x = lower_bound(b + 1, b + 1 + len, x) - b;
y = upper_bound(b + 1, b + 1 + len, y) - b;
y --;
ans = (f[y] - f[x - 1]) * (long long)dep[z] - 2 * (solve(z, y) - solve(z, x - 1)) + num[y] - num[x - 1];
printf("%lld\n", ans);
}
}