You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v: count how many distinct weights occur among the nodes on the path from u to v.
Input
The first line contains two integers N and M (N <= 40000, M <= 100000).
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
Each of the next M lines contains two integers u and v, describing one operation: count how many distinct weights occur among the nodes on the path from u to v (both endpoints included).
Output
For each operation, print its result.
Example
Input:
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
Output:
4
4
思路：本题的方法很多，这里我采用的是树上莫队。(Approach: this problem has many possible solutions; the one below uses Mo's algorithm on a tree.)
// Counts, for each query (u, v), how many distinct node weights occur on the
// tree path from u to v (both endpoints included).
//
// Approach: Mo's algorithm on the Euler tour of the tree.
// A DFS from node 1 gives every node x two positions in a sequence of length
// 2N: tin[x] when it is entered and tout[x] when it is left. For a query
// (u, v) with tin[u] <= tin[v]:
//   * if u is an ancestor of v, the nodes that occur exactly once among the
//     positions [tin[u], tin[v]] are exactly the nodes of the path;
//   * otherwise the nodes that occur exactly once among [tout[u], tin[v]] are
//     the path nodes except lca(u, v), which is added only for the answer.
// The sweep toggles a node whenever one of its positions enters or leaves the
// window, so "occurs exactly once in the window" equals "currently active".
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Reads one int from stdin; returns false on end of input or a parse error.
bool read_int(int &out) { return std::scanf("%d", &out) == 1; }

// Replaces weight[1..n] by its 0-based rank among the distinct weights and
// returns the number of distinct weights. weight[0] is left untouched.
int compress_weights(std::vector<int> &weight) {
    std::vector<int> sorted(weight.begin() + 1, weight.end());
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());
    for (std::size_t i = 1; i < weight.size(); ++i) {
        weight[i] = static_cast<int>(
            std::lower_bound(sorted.begin(), sorted.end(), weight[i]) -
            sorted.begin());
    }
    return static_cast<int>(sorted.size());
}

// Euler tour (entry/exit positions) of a rooted tree plus a binary-lifting
// table used to answer LCA queries.
class EulerTour {
  public:
    // adj is a 1-based adjacency list of a tree on nodes 1..n (adj[0] unused).
    EulerTour(const std::vector<std::vector<int>> &adj, int root) {
        const int n = static_cast<int>(adj.size()) - 1;
        tin_.assign(n + 1, 0);
        tout_.assign(n + 1, 0);
        order_.assign(2 * n + 1, 0);
        // Smallest level count with 2^levels >= n, so any depth can be lifted.
        while ((1 << levels_) < n) ++levels_;
        up_.assign(levels_, std::vector<int>(n + 1, root));

        // Iterative DFS: a path-shaped tree would otherwise recurse once per
        // node. Each frame holds (node, index of the next neighbour to visit).
        std::vector<std::pair<int, std::size_t>> stack;
        stack.reserve(n);
        int timer = 0;
        up_[0][root] = root;  // the root is its own parent, so lifting stops there
        tin_[root] = ++timer;
        order_[timer] = root;
        stack.emplace_back(root, std::size_t{0});
        while (!stack.empty()) {
            const int x = stack.back().first;
            std::size_t &next = stack.back().second;
            if (next < adj[x].size()) {
                const int v = adj[x][next++];
                if (v == up_[0][x]) continue;  // the edge back to the parent
                up_[0][v] = x;
                tin_[v] = ++timer;
                order_[timer] = v;
                // May invalidate `next`; it is not used again in this step.
                stack.emplace_back(v, std::size_t{0});
            } else {
                tout_[x] = ++timer;
                order_[timer] = x;
                stack.pop_back();
            }
        }

        for (int j = 1; j < levels_; ++j) {
            for (int x = 1; x <= n; ++x) {
                up_[j][x] = up_[j - 1][up_[j - 1][x]];
            }
        }
    }

    int tin(int x) const { return tin_[x]; }
    int tout(int x) const { return tout_[x]; }

    // Node recorded at Euler position pos (1 <= pos <= length()).
    int node_at(int pos) const { return order_[pos]; }

    // Number of Euler positions, i.e. 2 * n.
    int length() const { return static_cast<int>(order_.size()) - 1; }

    // True if x is an ancestor of y; every node is an ancestor of itself.
    bool is_ancestor(int x, int y) const {
        return tin_[x] <= tin_[y] && tout_[y] <= tout_[x];
    }

    int lca(int x, int y) const {
        if (is_ancestor(x, y)) return x;
        if (is_ancestor(y, x)) return y;
        // Lift x to the highest node that is still not an ancestor of y; its
        // parent is then the lowest common ancestor.
        for (int j = levels_ - 1; j >= 0; --j) {
            if (!is_ancestor(up_[j][x], y)) x = up_[j][x];
        }
        return up_[0][x];
    }

  private:
    std::vector<int> tin_;
    std::vector<int> tout_;
    std::vector<int> order_;
    std::vector<std::vector<int>> up_;  // up_[j][x]: 2^j-th ancestor of x
    int levels_ = 1;
};

// Tracks the set of "active" nodes and how many distinct weights they carry.
// Holds a reference to the weight vector, which must outlive this object.
class DistinctWeights {
  public:
    DistinctWeights(const std::vector<int> &weight, int distinct_values)
        : weight_(weight), active_(weight.size(), 0), count_(distinct_values, 0) {}

    // Activates node if it is inactive, deactivates it otherwise.
    void toggle(int node) {
        int &count = count_[weight_[node]];
        if (active_[node]) {
            if (--count == 0) --distinct_;
        } else {
            if (count++ == 0) ++distinct_;
        }
        active_[node] = !active_[node];
    }

    int distinct() const { return distinct_; }

  private:
    const std::vector<int> &weight_;  // compressed weight of each node
    std::vector<char> active_;
    std::vector<int> count_;          // active nodes per compressed weight
    int distinct_ = 0;
};

// One query, translated into an Euler-tour window.
struct Query {
    int left;   // first Euler position of the window
    int right;  // last Euler position of the window (left <= right)
    int extra;  // node counted on top of the window (the LCA), 0 if none
    int id;     // index of the query in input order
};

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (!read_int(n) || !read_int(m) || n < 1 || m < 0) return 1;

    std::vector<int> weight(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (!read_int(weight[i])) return 1;
    }
    const int distinct_values = compress_weights(weight);

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (!read_int(x) || !read_int(y)) return 1;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    const EulerTour tour(adj, 1);

    std::vector<Query> queries(m);
    for (int i = 0; i < m; ++i) {
        int u = 0;
        int v = 0;
        if (!read_int(u) || !read_int(v)) return 1;
        if (tour.tin(u) > tour.tin(v)) std::swap(u, v);
        const int w = tour.lca(u, v);
        if (w == u) {
            // u lies on the path above v: the window already contains u.
            queries[i] = {tour.tin(u), tour.tin(v), 0, i};
        } else {
            // Disjoint subtrees: the window misses only the LCA.
            queries[i] = {tour.tout(u), tour.tin(v), w, i};
        }
    }

    // Mo's ordering: group by block of the left end; within a block, sweep
    // the right end upwards in even blocks and downwards in odd blocks so the
    // right pointer does not jump back at every block boundary.
    const int block = std::max(
        1, static_cast<int>(std::sqrt(static_cast<double>(tour.length()))));
    std::sort(queries.begin(), queries.end(),
              [block](const Query &a, const Query &b) {
                  const int block_a = a.left / block;
                  const int block_b = b.left / block;
                  if (block_a != block_b) return block_a < block_b;
                  return (block_a & 1) ? a.right > b.right : a.right < b.right;
              });

    DistinctWeights active(weight, distinct_values);
    std::vector<int> answer(m, 0);
    int lo = 1;  // current window is [lo, hi]; initially empty
    int hi = 0;
    for (const Query &q : queries) {
        // Grow before shrinking so the window never inverts.
        while (hi < q.right) active.toggle(tour.node_at(++hi));
        while (lo > q.left) active.toggle(tour.node_at(--lo));
        while (hi > q.right) active.toggle(tour.node_at(hi--));
        while (lo < q.left) active.toggle(tour.node_at(lo++));
        if (q.extra != 0) {
            // The LCA lies outside the window, so this toggle activates it.
            active.toggle(q.extra);
            answer[q.id] = active.distinct();
            active.toggle(q.extra);
        } else {
            answer[q.id] = active.distinct();
        }
    }

    for (const int a : answer) std::printf("%d\n", a);
    return 0;
}