先树剖一下,对每种宗教建一棵线段树。
当然,把所有的节点都建出来,在时空复杂度上都不允许。
所以考虑不建出所有的节点。主要想法是:一棵线段树里,只维护部分叶子节点(信仰该宗教的城市)以及这些节点到线段树根的路径。可以知道,这样的建树复杂度为 $O(n \log n)$。
对于操作 1(修改城市的宗教),就在该城市原宗教对应的线段树里删掉这个叶子节点,并在新宗教对应的线段树里插入这个叶子节点。
对于操作 2、3、4,就是单点修改和路径询问,就不多说了。
代码:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// Reads the next decimal integer (optionally preceded by '-') from stdin.
//
// Every character before the first digit or '-' is skipped, then the run of
// digits that follows is consumed. The single character that terminates the
// number is consumed as well.
//
// The value returned by getchar() is kept in an int so that EOF cannot be
// confused with character data. If the input ends before a number starts,
// 0 is returned instead of looping forever.
inline int read() {
    int c = getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return 0;

    const bool negative = (c == '-');
    int res = negative ? 0 : c - '0';
    for (c = getchar(); c >= '0' && c <= '9'; c = getchar())
        res = res * 10 + (c - '0');
    return negative ? -res : res;
}
// Returns the next command letter found on stdin, i.e. the next character
// that is one of 'C', 'W', 'Q', 'S' or 'M'; every other character is skipped.
//
// The value returned by getchar() is kept in an int so that EOF cannot be
// confused with character data. If the input ends before a command letter
// is found, '\0' is returned instead of looping forever.
inline char get() {
    for (;;) {
        const int c = getchar();
        switch (c) {
            case 'C':
            case 'W':
            case 'Q':
            case 'S':
            case 'M':
                return static_cast<char>(c);
            case EOF:
                return '\0';
            default:
                break;  // not a command letter: keep scanning
        }
    }
}
// N bounds the number of cities (and the religion ids used to index rt).
const int N = 1e5 + 5;
// M bounds the number of segment-tree nodes ever allocated.
// ins() creates at most one node per level of a tree over [1, n] (18 levels
// for n <= 1e5) and del() never hands nodes back, so the pool must hold
// 18 * (n initial inserts + one insert per religion change). Assuming at most
// 1e5 operations (TODO confirm against the problem limits) that is up to
// 3.6e6 nodes, which the previous bound of 2e6 + 5 could not hold.
const int M = 4e6 + 5;

int n, Q;        // number of cities / number of operations, read in main
int W[N], C[N];  // not referenced by the code visible in this file

// Tree structure filled in by dfs1: parent, depth, subtree size, heavy child.
int fa[N], dep[N], sze[N], son[N];
// Heavy-path data filled in by init/dfs2: chain head and segment-tree position.
int top[N], pos[N];
// rt[c] is the root node of the segment tree of religion c (0 = empty tree).
int rt[N];

int QAQ;   // last segment-tree position handed out by init/dfs2
int ecnt;  // number of directed arcs stored so far (arc ids start at 1)
// Adjacency lists: adj[u] is the first arc of u, nxt[e] the next arc,
// go[e] the head of arc e. Each undirected edge uses two arcs.
int nxt[N << 1], adj[N], go[N << 1];

int col[N], val[N];  // current religion and value of each city
int QWQ;             // number of segment-tree nodes allocated so far
// Stores the undirected edge {u, v} as two directed arcs, each pushed onto
// the front of its tail's adjacency list.
void add_edge(int u, int v) {
    // arc u -> v
    const int forward = ++ecnt;
    nxt[forward] = adj[u];
    adj[u] = forward;
    go[forward] = v;

    // arc v -> u
    const int backward = ++ecnt;
    nxt[backward] = adj[v];
    adj[v] = backward;
    go[backward] = u;
}
// One node of a dynamically allocated segment tree.
// lc / rc are indices into T of the left and right child (0 = no child);
// Sum and Max aggregate the values stored in the node's range.
struct cyx {
    int lc, rc, Sum, Max;

    // Deliberately leaves the fields untouched.
    cyx() {}

    // A childless node whose Sum and Max both equal val.
    cyx(int val)
        : lc(0), rc(0), Sum(val), Max(val) {}
} T[M];  // node pool; ins() hands out indices starting at 1
void upt(int p) {
T[p].Sum = T[T[p].lc].Sum + T[T[p].rc].Sum;
T[p].Max = max(T[T[p].lc].Max, T[T[p].rc].Max);
}
// Sets position x to w in the tree rooted at p, which covers [l, r].
// Nodes missing on the way down are allocated from T; p is a reference so
// that a freshly allocated node is linked into its parent (or into rt[]).
void ins(int x, int w, int l, int r, int &p) {
    if (p == 0) {
        p = ++QWQ;
        T[p] = cyx(0);
    }
    if (l == r) {
        T[p].Max = w;
        T[p].Sum = w;
        return;
    }

    const int mid = (l + r) >> 1;
    if (x > mid) {
        ins(x, w, mid + 1, r, T[p].rc);
    } else {
        ins(x, w, l, mid, T[p].lc);
    }
    upt(p);  // runs after either branch
}
// First pass over the tree rooted at u (fu is u's parent): records parent,
// depth and subtree size, and picks the child with the largest subtree as
// the heavy child son[u].
void dfs1(int u, int fu) {
    fa[u] = fu;
    dep[u] = dep[fu] + 1;
    sze[u] = 1;

    for (int e = adj[u]; e != 0; e = nxt[e]) {
        const int v = go[e];
        if (v == fu) continue;  // do not walk back up to the parent

        dfs1(v, u);
        sze[u] += sze[v];
        // Strict comparison: on ties the child seen first stays heavy.
        if (sze[v] > sze[son[u]]) {
            son[u] = v;
        }
    }
}
// Second pass: assigns chain heads (top) and segment-tree positions (pos).
// The heavy child is visited first so that every heavy chain occupies a
// contiguous range of positions; each light child starts a chain of its own.
void dfs2(int u, int fu) {
    const int heavy = son[u];
    if (heavy != 0) {
        top[heavy] = top[u];
        pos[heavy] = ++QAQ;
        dfs2(heavy, u);
    }

    for (int e = adj[u]; e != 0; e = nxt[e]) {
        const int v = go[e];
        if (v == fu || v == heavy) continue;

        top[v] = v;
        pos[v] = ++QAQ;
        dfs2(v, u);
    }
}
void init() {
QAQ = top[1] = pos[1] = 1;
int i; dfs1(1, 0); dfs2(1, 0);
for (i = 1; i <= n; i++)
ins(pos[i], val[i], 1, n, rt[col[i]]);
}
// Removes leaf x from the tree rooted at p, which covers [l, r].
// The leaf, and every ancestor left without children, is reset and unlinked
// by writing 0 through the reference p. Reset nodes are not reused.
void del(int x, int l, int r, int &p) {
    if (l == r) {
        T[p] = cyx(0);
        p = 0;
        return;
    }

    const int mid = (l + r) >> 1;
    if (x > mid) {
        del(x, mid + 1, r, T[p].rc);
    } else {
        del(x, l, mid, T[p].lc);
    }
    upt(p);

    const bool childless = (T[p].lc == 0 && T[p].rc == 0);
    if (childless) {
        T[p] = cyx(0);
        p = 0;
    }
}
// Overwrites the value at position x with v in the tree rooted at p, which
// covers [l, r], and refreshes the aggregates on the way back up.
// No nodes are created here: the path to x is followed as it is.
void change(int l, int r, int x, int v, int p) {
    if (l == r) {
        T[p].Max = v;
        T[p].Sum = v;
        return;
    }

    const int mid = (l + r) >> 1;
    if (x > mid) {
        change(mid + 1, r, x, v, T[p].rc);
    } else {
        change(l, mid, x, v, T[p].lc);
    }
    upt(p);
}
// Moves city x to religion c: its leaf is removed from the tree of its
// current religion and inserted, with the same value, into the tree of c.
void changeC(int x, int c) {
    const int where = pos[x];
    del(where, 1, n, rt[col[x]]);
    ins(where, val[x], 1, n, rt[c]);
    col[x] = c;
}
// Sum of the values stored at positions [s, e] in the tree rooted at p,
// which covers [l, r]. A missing subtree (p == 0) contributes 0.
int querySum(int l, int r, int s, int e, int p) {
    if (p == 0) return 0;
    if (l == s && r == e) return T[p].Sum;

    const int mid = (l + r) >> 1;
    if (e <= mid) {
        // query lies entirely in the left half
        return querySum(l, mid, s, e, T[p].lc);
    }
    if (s > mid) {
        // query lies entirely in the right half
        return querySum(mid + 1, r, s, e, T[p].rc);
    }

    // query straddles mid: split it into the two halves
    const int leftPart = querySum(l, mid, s, mid, T[p].lc);
    const int rightPart = querySum(mid + 1, r, mid + 1, e, T[p].rc);
    return leftPart + rightPart;
}
// Maximum of the values stored at positions [s, e] in the tree rooted at p,
// which covers [l, r]. A missing subtree (p == 0) yields 0.
int queryMax(int l, int r, int s, int e, int p) {
    if (p == 0) return 0;
    if (l == s && r == e) return T[p].Max;

    const int mid = (l + r) >> 1;
    if (e <= mid) {
        // query lies entirely in the left half
        return queryMax(l, mid, s, e, T[p].lc);
    }
    if (s > mid) {
        // query lies entirely in the right half
        return queryMax(mid + 1, r, s, e, T[p].rc);
    }

    // query straddles mid: split it into the two halves
    const int leftBest = queryMax(l, mid, s, mid, T[p].lc);
    const int rightBest = queryMax(mid + 1, r, mid + 1, e, T[p].rc);
    return std::max(leftBest, rightBest);
}
// Sum over the tree path u..v of the values of the cities that share u's
// religion. Only the segment tree of col[u] is queried.
int pathSum(int u, int v) {
    // Captured before u and v can be swapped below.
    const int religion = col[u];
    int total = 0;

    // Lift the endpoint whose chain head is deeper, one chain at a time.
    for (; top[u] != top[v]; u = fa[top[u]]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        total += querySum(1, n, pos[top[u]], pos[u], rt[religion]);
    }

    // Both ends are on one chain now: order them from shallow to deep.
    if (dep[u] > dep[v]) std::swap(u, v);
    total += querySum(1, n, pos[u], pos[v], rt[religion]);
    return total;
}
// Maximum over the tree path u..v of the values of the cities that share
// u's religion. Only the segment tree of col[u] is queried; the running
// maximum starts at 0.
int pathMax(int u, int v) {
    // Captured before u and v can be swapped below.
    const int religion = col[u];
    int best = 0;

    // Lift the endpoint whose chain head is deeper, one chain at a time.
    for (; top[u] != top[v]; u = fa[top[u]]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        const int onChain = queryMax(1, n, pos[top[u]], pos[u], rt[religion]);
        best = std::max(best, onChain);
    }

    // Both ends are on one chain now: order them from shallow to deep.
    if (dep[u] > dep[v]) std::swap(u, v);
    const int lastPiece = queryMax(1, n, pos[u], pos[v], rt[religion]);
    return std::max(best, lastPiece);
}
// Input: n and Q; then for each city its value and religion; then the n - 1
// tree edges; then Q operations, each made of two command letters and two
// integers:
//   CC x c  move city x to religion c
//   CW x w  set the value of city x to w
//   QS x y  print the path sum between x and y
//   QM x y  print the path maximum between x and y
int main() {
    n = read();
    Q = read();

    for (int city = 1; city <= n; ++city) {
        val[city] = read();
        col[city] = read();
    }
    for (int edge = 1; edge < n; ++edge) {
        const int a = read();
        const int b = read();
        add_edge(a, b);
    }
    init();

    while (Q--) {
        const char kind = get();
        const char sub = get();
        const int x = read();
        const int y = read();

        if (kind != 'C') {
            // Query: 'S' selects the sum, anything else the maximum.
            const int answer = (sub == 'S') ? pathSum(x, y) : pathMax(x, y);
            printf("%d\n", answer);
            continue;
        }

        if (sub == 'C') {
            changeC(x, y);
        } else {
            change(1, n, pos[x], y, rt[col[x]]);
            val[x] = y;
        }
    }
    return 0;
}