Description
给定一棵N个结点的树,每个点一开始都是白色,进行Q次操作,操作有以下两种:
1、给定一个节点x,把x染成蓝色
2、给定一个节点x,询问x到其他所有蓝色点的距离和输入N, Q,startSeed,threshold, maxDist。
用以下方法生成这棵树以及Q次操作:
int curValue = startSeed;
int genNextRandom() {
curValue = (curValue * 1999 + 17) % 1000003;
return curValue;
}
void generateInput() {
for (int i = 0; i < N-1; i++) {
distance[i] = genNextRandom() % maxDist;
parent[i] = genNextRandom();
if (parent[i] < threshold) {
parent[i] = i;
} else {
parent[i] = parent[i] % (i + 1);
}
}
for (int i = 0; i < Q; i++) {
queryType[i] = genNextRandom() % 2 + 1;
queryNode[i] = genNextRandom() % N;
}
}
以上程序输出四个数组:parent,distance,queryType以及queryNode。
其中parent、distance有N-1个元素,对于每个i(0<=i<=n-2),(i+1)与parent[i]有一条边相连,长度为distance[i]。注意0<=parent[i]<=i。
queryType、queryNode有Q个元素,对于每个i,操作种类是queryType[i],操作的节点是queryNode[i].
输出所有操作2答案的异或值。
Input
输入一行,包含N, Q,startSeed,threshold, maxDist
Output
输出所有操作2答案的异或值。
Sample Input
输入1:
4 6 15 2 5
输入2:
4 5 2 9 10
输入3:
14750 50 29750 1157 21610
Sample Output
输出1:
7
输出2:
30
输出3:
2537640
Data Constraint
2<=N<=100,000
1<=Q<=100,000
0<=startSeed<=1,000,002
0<=threshold<=1,000,003
1<=maxDist<=1,000,003
Solution
由于查询的是一个点到所有其他蓝点路径长度和,如果我们考虑了树上任意一对点的路径也就考虑了所有的询问,于是考虑点分治。
首先离线询问,对于每个点记录其第一次染成蓝色的时间 t i t_i ti,把对某个点的询问放在该点上,并记录询问的时间 q i q_i qi。点分治时,把树状数组 t i t_i ti位置加上 i i i到当前点的距离,同时再把个数用树状数组统计一下,然后枚举每个点,枚举该点上的询问,在树状数组里查找在这次询问之前的蓝点的贡献,记录答案。但是这样会有两点在同一子树内的情况,我们对于各子树删去这种情况即可。
复杂度 O ( ( n + q ) l o g 2 n ) O((n+q)log^2n) O((n+q)log2n)。
Code
#include <cstdio>
#include <cstring>
typedef long long ll;
const int N = 100007, INF = 0x3f3f3f3f;
int min(int a, int b) { return a < b ? a : b; }
int max(int a, int b) { return a > b ? a : b; }
int n, q;
int curValue, startSeed, threshold, maxDist, distance[N], parent[N], queryType[N], queryNode[N];
int genNextRandom()
{
curValue = (curValue * 1999 + 17) % 1000003;
return curValue;
}
void generateInput()
{
curValue = startSeed;
for (int i = 0; i < n - 1; i++)
{
distance[i] = genNextRandom() % maxDist;
parent[i] = genNextRandom();
if (parent[i] < threshold) parent[i] = i;
else parent[i] = parent[i] % (i + 1);
}
for (int i = 0; i < q; i++)
{
queryType[i] = genNextRandom() % 2 + 1;
queryNode[i] = genNextRandom() % n;
}
}
int p, sum;
int tot, st[N], to[N << 1], nx[N << 1], size[N], mxsiz[N], del[N], tim[N], alen, arr[N];
ll len[N << 1], dis[N], ans[N], ret;
void add(int u, int v, ll w) { if (!u || !v) return; to[++tot] = v, nx[tot] = st[u], len[tot] = w; st[u] = tot; }
int cnt, head[N], link[N], next[N];
void insert(int u, int id) { link[++cnt] = id, next[cnt] = head[u], head[u] = cnt; }
void getp(int u, int from)
{
size[u] = 1, mxsiz[u] = 0;
for (int i = st[u]; i; i = nx[i])
if (to[i] != from && !del[to[i]])
getp(to[i], u), size[u] += size[to[i]], mxsiz[u] = max(mxsiz[u], size[to[i]]);
mxsiz[u] = max(mxsiz[u], sum - size[u]);
if (mxsiz[u] < mxsiz[p]) p = u;
}
ll tr[N][2];
void plus(int po, ll val, int k) { for (; po <= q + 1; po += (po & (-po))) tr[po][k] += val; }
ll getsum(int po, int k) { ll ret = 0; for (; po; po -= (po & (-po))) ret += tr[po][k]; return ret; }
void clear(int po) { for (; po <= q + 1; po += (po & (-po))) tr[po][0] = tr[po][1] = 0; }
void getdis(int u, int from)
{
arr[++alen] = u;
for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) dis[to[i]] = dis[u] + len[i], getdis(to[i], u);
}
void calc(int u, int val, int t)
{
alen = 0, dis[u] = val, getdis(u, 0);
for (int i = 1; i <= alen; i++) if (tim[arr[i]] < INF) plus(tim[arr[i]], 1, 0), plus(tim[arr[i]], dis[arr[i]], 1);
for (int i = 1; i <= alen; i++)
{
int w = arr[i];
for (int j = head[w]; j; j = next[j]) ans[link[j]] += t * (dis[w] * getsum(link[j] - 1, 0) + getsum(link[j] - 1, 1));
}
for (int i = 1; i <= alen; i++) if (tim[arr[i]] < INF) clear(tim[arr[i]]);
}
void solve(int u)
{
calc(u, 0, 1); //总共统计一次答案
del[u] = 1;
for (int i = st[u]; i; i = nx[i]) if (!del[to[i]]) calc(to[i], len[i], -1); //各子树分别删去答案
for (int i = st[u]; i; i = nx[i])
if (!del[to[i]])
{
sum = size[to[i]], p = 0;
getp(to[i], 0), solve(p);
}
}
int main()
{
scanf("%d%d%d%d%d", &n, &q, &startSeed, &threshold, &maxDist);
generateInput();
for (int i = 0; i < n - 1; i++) add(i + 2, parent[i] + 1, distance[i]), add(parent[i] + 1, i + 2, distance[i]);
memset(tim, 0x3f, sizeof(tim));
for (int i = 0; i < q; i++)
if (queryType[i] == 1) tim[queryNode[i] + 1] = min(tim[queryNode[i] + 1], i + 1);
else insert(queryNode[i] + 1, i + 1);
mxsiz[0] = INF;
sum = n, p = 0;
getp(1, 0);
solve(p);
ret = 0;
for (int i = 1; i <= q; i++) ret ^= ans[i];
printf("%lld\n", ret);
return 0;
}