题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=2243
算法讨论:
树链剖分把树放到线段树上。然后线段树的每个节点要维护的东西有左端点的颜色,右端点的颜色,以及是否被改变过颜色,和颜色段数。
向上合并的过程中,要注意如果左孩子的右端点和右孩子的左端点颜色相同,那么就要把颜色段数减一。
然后我们考虑询问的问题:
对于一个询问,我们是按深度从下向上跳着计算的,所以每次统计一个路径的时候,我们要记录上一次的路径的两端点的颜色,如果本次路径计算的
右端点颜色和上次的左端点颜色相同,那么答案就要减去1.(注意树链剖分,如果在同一次线段树内查询,那么节点深度小的一定线段树的编号也小)。
当两个节点跳到一个重链上的时候,那么此时两边的颜色都要进行判断一下。至于为啥,很好想吧。
题目代码:
1 #include <iostream> 2 #include <algorithm> 3 #include <cstring> 4 #include <cstdlib> 5 #include <cstdio> 6 #include <cctype> 7 8 using namespace std; 9 const int N = 100000 + 5; 10 inline int read() { 11 int x = 0; 12 char c = getchar(); 13 while(!isdigit(c)) c = getchar(); 14 while(isdigit(c)) { 15 x = x * 10 + c - '0'; 16 c = getchar(); 17 } 18 return x; 19 } 20 21 struct SegmentTree { 22 int lc, rc, l, r, tag, sz; 23 }Node[N * 4]; 24 struct Edge { 25 int from, to, next; 26 }edges[N << 1]; 27 28 char ss[3]; 29 int n, m, cnt, pos, co[N], lco, rco; 30 int fa[N], head[N], son[N], size[N]; 31 int num[N], top[N], depth[N], seg[N]; 32 33 void insert(int from, int to) { 34 ++ cnt; 35 edges[cnt].from = from; edges[cnt].to = to; 36 edges[cnt].next = head[from]; head[from] = cnt; 37 } 38 39 void dfs_1(int u, int f) { 40 fa[u] = f; size[u] = 1; 41 for(int i = head[u]; i; i = edges[i].next) { 42 int v = edges[i].to; 43 if(v != f) { 44 depth[v] = depth[u] + 1; 45 dfs_1(v, u); 46 size[u] += size[v]; 47 if(!son[u] || size[v] > size[son[u]]) 48 son[u] = v; 49 } 50 } 51 } 52 53 void dfs_2(int u, int ances) { 54 top[u] = ances; 55 num[u] = ++ pos; 56 seg[pos] = u; 57 if(!son[u]) return; 58 dfs_2(son[u], ances); 59 for(int i = head[u]; i; i = edges[i].next) { 60 int v = edges[i].to; 61 if(v != fa[u] && v != son[u]) { 62 dfs_2(v, v); 63 } 64 } 65 } 66 67 void pushdown(int o) { 68 if(Node[o].l == Node[o].r) return; 69 int l = o << 1, r = o << 1 | 1; 70 if(Node[o].tag) { 71 Node[l].tag = Node[r].tag = Node[o].tag; 72 Node[l].lc = Node[l].rc = Node[o].tag; 73 Node[r].lc = Node[r].rc = Node[o].tag; 74 Node[l].sz = Node[r].sz = 1; 75 Node[o].tag = 0; 76 } 77 } 78 79 void pushup(int o) { 80 if(Node[o].l == Node[o].r) return; 81 int l = o << 1, r = o << 1 | 1; 82 Node[o].lc = Node[l].lc; 83 Node[o].rc = Node[r].rc; 84 Node[o].sz = Node[l].sz + Node[r].sz - (Node[l].rc == Node[r].lc); 85 } 86 87 void build(int o, int l, int r) { 88 Node[o].l = l; Node[o].r = r; Node[o].tag = 0; 89 if(l == r) { 90 Node[o].sz = 1; 91 Node[o].lc = Node[o].rc = co[seg[l]]; 92 return; 93 } 94 int mid = (l + r) >> 1; 95 build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 96 pushup(o); 97 } 98 99 void update(int o, int l, int r, int v) { 100 if(Node[o].l == l && Node[o].r == r) { 101 Node[o].lc = Node[o].rc = v; 102 Node[o].tag = v; Node[o].sz = 1; 103 return; 104 } 105 int mid = (Node[o].l + Node[o].r) >> 1; 106 pushdown(o); 107 if(r <= mid) update(o << 1, l, r, v); 108 else if(l > mid) update(o << 1 | 1, l, r, v); 109 else { 110 update(o << 1, l, mid, v); 111 update(o << 1 | 1, mid + 1, r, v); 112 } 113 pushup(o); 114 } 115 116 int query(int o, int l, int r, int L, int R) { 117 if(Node[o].l == L) lco = Node[o].lc; 118 if(Node[o].r == R) rco = Node[o].rc; 119 if(Node[o].l == l && Node[o].r == r) { 120 return Node[o].sz; 121 } 122 int mid = (Node[o].l + Node[o].r) >> 1; 123 pushdown(o); 124 if(r <= mid) return query(o << 1, l, r, L, R); 125 else if(l > mid) return query(o << 1 | 1, l, r, L, R); 126 else { 127 return query(o << 1, l, mid, L, R) + query(o << 1 | 1, mid + 1, r, L, R) - (Node[o << 1].rc == Node[o << 1 | 1].lc); 128 } 129 pushup(o); 130 } 131 132 void Update(int x, int y, int z) { 133 int f1 = top[x], f2 = top[y]; 134 while(f1 != f2) { 135 if(depth[f1] < depth[f2]) { 136 swap(x, y); swap(f1, f2); 137 } 138 update(1, num[f1], num[x], z); 139 x = fa[f1]; f1 = top[x]; 140 } 141 if(depth[x] < depth[y]) { 142 update(1, num[x], num[y], z); 143 } 144 else { 145 update(1, num[y], num[x], z); 146 } 147 } 148 149 int Query(int x, int y) { 150 int f1 = top[x], f2 = top[y], res = 0; 151 int ans1 = -1, ans2 = -1; 152 while(f1 != f2) { 153 if(depth[f1] < depth[f2]) { 154 swap(f1, f2); swap(x, y); 155 swap(ans1, ans2); 156 } 157 res += query(1, num[f1], num[x], num[f1], num[x]); 158 if(rco == ans1) res --; ans1 = lco; 159 x = fa[f1]; f1 = top[x]; 160 } 161 if(depth[x] < depth[y]) { 162 swap(x, y); swap(ans1, ans2); 163 } 164 res += query(1, num[y], num[x], num[y], num[x]); 165 if(rco == ans1) res --; 166 if(lco == ans2) res --; 167 return res; 168 } 169 170 int main() { 171 int x, y, z; 172 n = read(); m = read(); 173 for(int i = 1; i <= n; ++ i) co[i] = read(); 174 for(int i = 1; i < n; ++ i) { 175 x = read(); y = read(); 176 insert(x, y); insert(y, x); 177 } 178 depth[1] = 1; 179 dfs_1(1, -1); dfs_2(1, 1); 180 build(1, 1, n); 181 for(int i = 1; i <= m; ++ i) { 182 scanf("%s", ss); 183 if(ss[0] == 'C') { 184 x = read(); y = read(); z = read(); 185 Update(x, y, z); 186 } 187 else { 188 x = read(); y = read(); 189 printf("%d\n", Query(x, y)); 190 } 191 } 192 return 0; 193 }