最近总是在做树链剖分的题(觉得有必要学一下倍增算法=_=)。这题也是一个树链剖分:线段树的每个节点维护区间内的颜色段数以及区间左、右端点的颜色。合并两个相邻区间(包括拼接两条重链)时,如果交界处两侧的颜色相同,颜色段数要减一,维护和找答案的时候注意这一点就OK了……
上代码:
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <cmath>

#define N 100010
#define inf 0x7f7f7f7f

// Path recolouring and colour-segment counting on a tree, implemented with
// heavy-light decomposition on top of a lazy-propagation segment tree.
//
// Input (stdin): n m, then n node colours, then n-1 edges, then m operations:
//   C a b c  -- paint every node on the path a..b with colour c
//   <other> a b -- print the number of maximal same-colour runs on path a..b
//
// NOTE: there is intentionally no file-scope `using namespace std;`. The
// adjacency "next" array used to be called `next`, which is ambiguous with
// std::next under such a directive in C++11 and later; it is now `nxt`.

// One segment-tree node covering a contiguous range of chain positions.
struct sss {
    int color;  // value to propagate to the children while `push` is set
    int lc, rc; // colours of the leftmost / rightmost position of the range
    int num;    // number of maximal same-colour runs inside the range
    int push;   // non-zero: the whole range was painted and children are stale
} t[N * 4];

int n, m, color[N], nowplace = 0;
// fa: parent, deep: depth (root = 1), w: position in the segment tree,
// top: head of the heavy chain, son: heavy child (0 if none), siz: subtree size.
int fa[N], deep[N], w[N], top[N], son[N] = {0}, siz[N] = {0};
// Adjacency lists: p[x] is the first edge id of x (0 = none), v[e] the target
// of edge e, nxt[e] the following edge id in the same list.
int p[N] = {0}, v[N * 2], nxt[N * 2], bnum = 0;

// Zero every node of the segment tree rooted at `now` covering [l, r].
void build_tree(int now, int l, int r) {
    t[now].push = 0;
    t[now].num = 0;
    t[now].color = t[now].lc = t[now].rc = 0;
    if (l == r)
        return;
    int mid = (l + r) / 2;
    build_tree(now * 2, l, mid);
    build_tree(now * 2 + 1, mid + 1, r);
}

// Recompute node `now` from its two children.
void update(int now) {
    const sss &left = t[now * 2];
    const sss &right = t[now * 2 + 1];
    t[now].num = left.num + right.num;
    // The runs touching the middle boundary fuse when their colours agree.
    if (left.rc == right.lc)
        t[now].num--;
    t[now].lc = left.lc;
    t[now].rc = right.rc;
    if (left.color && right.color == left.color)
        t[now].color = left.color;
    else
        t[now].color = 0;
}

// Push a pending "paint whole range" mark from `now` down to its children.
void downdate(int now) {
    if (!t[now].push)
        return;
    const int c = t[now].color;
    t[now].push = 0;
    for (int child = now * 2; child <= now * 2 + 1; ++child) {
        t[child].push = 1;
        t[child].color = c;
        t[child].lc = t[child].rc = c;
        t[child].num = 1;
    }
}

// Return the colour stored at position `place`; node `now` covers [l, r].
int findcolor(int now, int l, int r, int place) {
    if (l == r)
        return t[now].color;
    int mid = (l + r) / 2;
    downdate(now);
    if (place <= mid)
        return findcolor(now * 2, l, mid, place);
    return findcolor(now * 2 + 1, mid + 1, r, place);
}

// Zero only the leaves below `now` (internal nodes are left untouched).
// Not called anywhere in this file; retained so the symbol stays available.
void make_tree(int now, int l, int r) {
    if (l == r) {
        t[now].color = t[now].lc = t[now].rc = 0;
        t[now].num = 0;
        t[now].push = 0;
        return;
    }
    int mid = (l + r) / 2;
    make_tree(now * 2, l, mid);
    make_tree(now * 2 + 1, mid + 1, r);
}

// Paint positions [cl, cr] with colour `cnum`; node `now` covers [l, r].
void tchange(int now, int l, int r, int cl, int cr, int cnum) {
    if (cl <= l && r <= cr) {
        t[now].color = cnum;
        t[now].lc = t[now].rc = cnum;
        t[now].num = 1;
        t[now].push = 1;
        return;
    }
    int mid = (l + r) / 2;
    downdate(now);
    if (cl <= mid)
        tchange(now * 2, l, mid, cl, cr, cnum);
    if (cr > mid)
        tchange(now * 2 + 1, mid + 1, r, cl, cr, cnum);
    update(now);
}

// Count the maximal same-colour runs inside positions [al, ar];
// node `now` covers [l, r].
int task(int now, int l, int r, int al, int ar) {
    if (al <= l && r <= ar)
        return t[now].num;
    int mid = (l + r) / 2, ans = 0;
    downdate(now);
    if (al <= mid)
        ans += task(now * 2, l, mid, al, ar);
    if (ar > mid)
        ans += task(now * 2 + 1, mid + 1, r, al, ar);
    // When the query straddles mid, positions mid and mid+1 are both inside
    // it. After downdate(now) the children are up to date, so their rc / lc
    // are exactly the colours at mid / mid+1: no need to walk down from the
    // root with findcolor() to learn them.
    if (al <= mid && ar > mid && t[now * 2].rc == t[now * 2 + 1].lc)
        ans--;
    return ans;
}

// Add the undirected edge x-y to the adjacency lists (two directed entries).
void addbian(int x, int y) {
    bnum++;
    nxt[bnum] = p[x];
    p[x] = bnum;
    v[bnum] = y;
    bnum++;
    nxt[bnum] = p[y];
    p[y] = bnum;
    v[bnum] = x;
}

// First pass: record parent, depth and subtree size, and pick the heavy child
// (the child with the largest subtree) of every node.
// Recursion depth equals the height of the tree.
void dfs_1(int now, int fat, int nowdeep) {
    fa[now] = fat;
    deep[now] = nowdeep;
    siz[now] = 1;
    int maxson = 0;
    for (int k = p[now]; k; k = nxt[k]) {
        const int child = v[k];
        if (child == fat)
            continue;
        dfs_1(child, now, nowdeep + 1);
        siz[now] += siz[child];
        if (siz[child] > maxson) {
            maxson = siz[child];
            son[now] = child;
        }
    }
}

// Second pass: hand out segment-tree positions so that every heavy chain is
// contiguous, record each node's chain head, and store its initial colour.
void dfs_2(int now, int fat, int nowtop) {
    w[now] = ++nowplace;
    top[now] = nowtop;
    tchange(1, 1, n, w[now], w[now], color[now]);
    // Visit the heavy child first so it receives the very next position.
    if (son[now])
        dfs_2(son[now], now, nowtop);
    for (int k = p[now]; k; k = nxt[k]) {
        const int child = v[k];
        if (child != fat && child != son[now])
            dfs_2(child, now, child); // a light child starts its own chain
    }
}

// Paint every node on the tree path u..v with colour `changenum`.
void change(int u, int v, int changenum) {
    // Repeatedly lift the endpoint whose chain head is deeper.
    while (top[u] != top[v]) {
        if (deep[top[u]] < deep[top[v]])
            std::swap(u, v);
        tchange(1, 1, n, w[top[u]], w[u], changenum);
        u = fa[top[u]];
    }
    // Same chain: the remaining path is one contiguous position range.
    tchange(1, 1, n, std::min(w[u], w[v]), std::max(w[u], w[v]), changenum);
}

// Return the number of maximal same-colour runs on the tree path u..v.
int ask(int u, int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (deep[top[u]] < deep[top[v]])
            std::swap(u, v);
        const int head = top[u];
        ans += task(1, 1, n, w[head], w[u]);
        // The chain head and its parent are adjacent on the path; if they
        // share a colour, the run continues across the chain boundary.
        if (findcolor(1, 1, n, w[head]) == findcolor(1, 1, n, w[fa[head]]))
            ans--;
        u = fa[head];
    }
    if (u == v)
        return ans + 1;
    return ans + task(1, 1, n, std::min(w[u], w[v]), std::max(w[u], w[v]));
}

// True when x is a usable node index for the current n.
static bool valid_node(int x) {
    return x >= 1 && x <= n;
}

// Reads the tree and the operations from stdin and prints one line per query.
// Returns 1 when the header, colours or edges are malformed; stops quietly
// (returning 0) when the operation list ends early or is malformed.
int main() {
    // n must be positive (build_tree never terminates on an empty range) and
    // must fit the statically sized arrays.
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || n >= N)
        return 1;
    build_tree(1, 1, n);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &color[i]) != 1)
            return 1;
    }
    for (int i = 1; i < n; ++i) {
        int x, y;
        if (std::scanf("%d%d", &x, &y) != 2 || !valid_node(x) || !valid_node(y))
            return 1;
        addbian(x, y);
    }
    dfs_1(1, 0, 1);
    dfs_2(1, 0, 1);
    for (int i = 1; i <= m; ++i) {
        char s[8];
        // Bounded width: an unbounded %s into a tiny buffer can overflow.
        if (std::scanf("%7s", s) != 1)
            break;
        if (s[0] == 'C') {
            int a, b, c;
            if (std::scanf("%d%d%d", &a, &b, &c) != 3 || !valid_node(a) || !valid_node(b))
                break;
            change(a, b, c);
        } else {
            int a, b;
            if (std::scanf("%d%d", &a, &b) != 2 || !valid_node(a) || !valid_node(b))
                break;
            std::printf("%d\n", ask(a, b));
        }
    }
    return 0;
}