Description
给定一棵 n 个节点的树，每个节点有一个初始颜色，以及 m 个操作：每次可以将树上任意两点之间路径上的所有节点染成同一种颜色，或者查询任意两点之间路径上有多少段颜色（连续的相同颜色算作一段）。
Solution
树链剖分部分不再解释，这里主要分析线段树的维护。对线段树上的每个区间，我们维护区间左端点的颜色、右端点的颜色，以及整个区间内颜色段的数量。合并两个区间时，先将两个区间的颜色段数量相加；如果左区间右端的颜色等于右区间左端的颜色，说明这两段在交界处连成了一段，答案减一。
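剩下的就是线段树和树链剖分的基本操作，在此不再赘述。唯一需要注意的是：沿路径向上跳重链时，每统计完一条重链，都要比较该重链顶端节点与其父节点的颜色。这两个节点在树上相邻，却位于线段树中不相邻的两个区间，区间合并无法处理它们的交界，因此若颜色相同，答案要额外减一。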
Code
// Tree path painting / colour-segment counting.
//
// Heavy-light decomposition maps every root-to-leaf heavy chain onto a
// contiguous range of positions; a lazy segment tree over those positions
// stores, for each interval, its leftmost colour, rightmost colour and the
// number of maximal runs of equal colour.
//
// Input : n m, then n initial colours, then n-1 undirected edges, then m
//         operations "C x y z" (paint path x-y with colour z) or
//         "Q x y" (print number of colour runs on path x-y).
#include <cctype>
#include <cstdio>
#include <utility>

namespace {

constexpr int kMaxN = 100010;

// Reads a decimal integer from stdin, skipping leading non-digit characters
// (a '-' among them flips the sign). Returns 0 if EOF is hit first instead
// of spinning forever as a char-typed reader would.
inline int read() {
    int ret = 0, sign = 1;
    int c = std::getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        ret = ret * 10 + (c - '0');
        c = std::getchar();
    }
    return ret * sign;
}

// Reads the next non-whitespace character (same semantics as
// `std::cin >> ch`), or EOF. Keeps all input on C stdio instead of mixing
// getchar() with iostreams.
inline int readOp() {
    int c = std::getchar();
    while (c != EOF && std::isspace(c)) c = std::getchar();
    return c;
}

// Adjacency list stored as a linked list of edges (index 0 = end of list).
struct Edge {
    int next, to;
} a[kMaxN << 1];
int head[kMaxN], num;

void add(int from, int to) {
    a[++num] = Edge{head[from], to};
    head[from] = num;
}

// Segment-tree node.
struct Segment {
    int lc, rc;   // colour at the left / right end of the interval
    int cnt;      // number of maximal same-colour runs in the interval
    int tag;      // pending "paint whole interval" colour, valid if hasTag
    bool hasTag;  // explicit flag so any int (including -1) can be a colour
} s[kMaxN << 2];

int n, m;
int in[kMaxN];  // initial colour of each vertex
// Subtree size, heavy child, depth, parent, chain top. Named `sz` rather than
// `size`: a global `size` next to `using namespace std` collides with
// C++17 std::size and fails to compile.
int sz[kMaxN], son[kMaxN], dep[kMaxN], f[kMaxN], top[kMaxN];
int seg[kMaxN];  // vertex -> segment-tree position
int rev[kMaxN];  // segment-tree position -> vertex
int tot;         // number of positions assigned so far

// First pass: depth, parent, subtree size and heavy child of every vertex.
void dfs1(int u, int fa) {
    dep[u] = dep[fa] + 1;
    f[u] = fa;
    sz[u] = 1;
    for (int i = head[u]; i; i = a[i].next) {
        const int v = a[i].to;
        if (v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[son[u]] < sz[v]) son[u] = v;  // sz[0] == 0 for "no son yet"
    }
}

// Second pass: assign positions so each heavy chain is contiguous, heavy
// child first. `t` is the top vertex of the chain containing `u`.
void dfs2(int u, int t) {
    top[u] = t;
    seg[u] = ++tot;
    rev[tot] = u;
    if (!son[u]) return;
    dfs2(son[u], t);
    for (int i = head[u]; i; i = a[i].next) {
        const int v = a[i].to;
        if (v != f[u] && v != son[u]) dfs2(v, v);  // light child starts a chain
    }
}

// Paints the whole interval of node `now` with colour `val` (lazily).
void applyColor(int now, int val) {
    s[now].lc = s[now].rc = val;
    s[now].cnt = 1;
    s[now].tag = val;
    s[now].hasTag = true;
}

// Merges the children: runs add up, minus one if they meet in equal colour.
void pushup(int now) {
    const Segment &left = s[now << 1], &right = s[now << 1 | 1];
    s[now].lc = left.lc;
    s[now].rc = right.rc;
    s[now].cnt = left.cnt + right.cnt - (left.rc == right.lc ? 1 : 0);
}

// Propagates a pending paint to both children.
void pushdown(int now) {
    if (!s[now].hasTag) return;
    applyColor(now << 1, s[now].tag);
    applyColor(now << 1 | 1, s[now].tag);
    s[now].hasTag = false;
}

void build(int now, int l, int r) {
    s[now].hasTag = false;
    if (l == r) {
        s[now].lc = s[now].rc = in[rev[l]];
        s[now].cnt = 1;
        return;
    }
    const int mid = (l + r) >> 1;
    build(now << 1, l, mid);
    build(now << 1 | 1, mid + 1, r);
    pushup(now);
}

// Paints positions [x, y] with colour val.
void update(int now, int l, int r, int x, int y, int val) {
    if (x <= l && r <= y) {
        applyColor(now, val);
        return;
    }
    pushdown(now);
    const int mid = (l + r) >> 1;
    if (x <= mid) update(now << 1, l, mid, x, y, val);
    if (y > mid) update(now << 1 | 1, mid + 1, r, x, y, val);
    pushup(now);
}

// Number of colour runs among positions [x, y].
int queryCount(int now, int l, int r, int x, int y) {
    if (x <= l && r <= y) return s[now].cnt;
    pushdown(now);
    const int mid = (l + r) >> 1;
    if (y <= mid) return queryCount(now << 1, l, mid, x, y);
    if (x > mid) return queryCount(now << 1 | 1, mid + 1, r, x, y);
    int ret = queryCount(now << 1, l, mid, x, y) +
              queryCount(now << 1 | 1, mid + 1, r, x, y);
    // The range straddles mid, so positions mid and mid+1 are both inside it;
    // the children are up to date after pushdown, so their ends are exact.
    if (s[now << 1].rc == s[now << 1 | 1].lc) --ret;
    return ret;
}

// Colour stored at a single position.
int colorAt(int pos) {
    int now = 1, l = 1, r = tot;
    while (l < r) {
        pushdown(now);
        const int mid = (l + r) >> 1;
        if (pos <= mid) {
            now = now << 1;
            r = mid;
        } else {
            now = now << 1 | 1;
            l = mid + 1;
        }
    }
    return s[now].lc;
}

// Paints every vertex on the tree path x-y with colour val.
void updatePath(int x, int y, int val) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        update(1, 1, tot, seg[top[x]], seg[x], val);
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    update(1, 1, tot, seg[x], seg[y], val);
}

// Number of colour runs on the tree path x-y.
int queryPath(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        ans += queryCount(1, 1, tot, seg[top[x]], seg[x]);
        // top[x] and its parent are adjacent on the path but live in
        // non-adjacent segment-tree ranges, so their junction is checked here.
        // top[x] cannot be the root here, so f[top[x]] is a real vertex.
        if (colorAt(seg[top[x]]) == colorAt(seg[f[top[x]]])) --ans;
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    ans += queryCount(1, 1, tot, seg[x], seg[y]);
    return ans;
}

}  // namespace

int main() {
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) in[i] = read();
    for (int i = 1; i < n; ++i) {
        const int x = read();
        const int y = read();
        add(x, y);
        add(y, x);
    }
    dfs1(1, 0);
    dfs2(1, 1);  // root is its own chain top and gets position 1
    build(1, 1, tot);
    while (m--) {
        const int op = readOp();
        if (op == EOF) break;
        const int x = read();
        const int y = read();
        if (op == 'C') {
            updatePath(x, y, read());
        } else {
            std::printf("%d\n", queryPath(x, y));
        }
    }
    return 0;
}