题意:
中文题意不解释了。
解析:
用树链剖分求解该题。
在线性序列用线段树里面怎么维护颜色段。其实很简单,每个线段树的节点位置记录三个值:
lc:这段区间最左端的颜色;rc:这段区间最右端的颜色;sumv:这段区间里面的颜色数。
那么进行区间合并的时候,就可以:
sumv[o]=sumv[ls]+sumv[rs]−(rc[ls]==lc[rs]);还有个问题就是,维护路径时要注意的是。
统计答案的时候要记录下上一次剖到的链的左端点的颜色,与当前剖到的链右端点的颜色(因为在处理出的线段树中越靠近根的点位置越左),比较这两个颜色,若相同则答案减1。
my code
#include<stdio.h>
#include<string.h>
#include<vector>
#include<algorithm>
#define ls o<<1
#define rs o<<1|1
#define lson ls,L,M
#define rson rs,M+1,R
using namespace std;
const int N = 100005;
vector<int>G[N];
int color[N], fa[N], son[N], size[N], top[N], deep[N];
int p[N], fp[N], dfs_clock;
int n;
void dfs1(int u, int pre, int de) {
deep[u] = de, fa[u] = pre, size[u] = 1;
son[u] = 0;
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v == pre) continue;
dfs1(v, u, de + 1);
size[u] += size[v];
if (size[v] > size[son[u]])
son[u] = v;
}
}
void dfs2(int u, int tp) {
top[u] = tp;
p[u] = ++dfs_clock;
fp[dfs_clock] = u;
if(son[u]) dfs2(son[u], tp);
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v != fa[u] && v != son[u])
dfs2(v, v);
}
}
void init_chain(int u) {
dfs_clock = 0;
dfs1(u, 0, 1);
dfs2(u, u);
}
int setv[N << 2], lc[N << 2], rc[N << 2], sumv[N << 2];
bool cover[N << 2];
void pushUp(int o) {
lc[o] = lc[ls], rc[o] = rc[rs];
sumv[o] = sumv[ls] + sumv[rs] - (rc[ls] == lc[rs]);
if(setv[ls] == setv[rs])
setv[o] = setv[ls];
}
void pushDown(int o) {
if (cover[o]) {
sumv[ls] = sumv[rs] = 1;
cover[ls] = cover[rs] = true;
setv[ls] = setv[rs] = setv[o];
lc[ls] = rc[ls] = lc[rs] = rc[rs] = setv[o];
setv[o] = cover[o] = 0;
}
}
void build(int o, int L, int R) {
if (L == R) {
lc[o] = rc[o] = setv[o] = color[fp[L]];
sumv[o] = 1;
return ;
}
int M = (L + R) >> 1;
build(lson);
build(rson);
pushUp(o);
}
void update(int o, int L, int R, int ql, int qr, int val) {
if (ql <= L && R <= qr) {
lc[o] = rc[o] = setv[o] = val;
sumv[o] = 1;
cover[o] = true;
return ;
}
int M = (L + R) >> 1;
pushDown(o);
if (ql <= M) update(lson, ql, qr, val);
if (qr > M) update(rson, ql, qr, val);
pushUp(o);
}
int get(int o, int L, int R, int pos) {
if (L == R) return lc[o];
int M = (L + R) >> 1;
pushDown(o);
if (pos <= M) return get(lson, pos);
else return get(rson, pos);
}
int query(int o, int L, int R, int ql, int qr) {
if (ql <= L && R <= qr) return sumv[o];
int M = (L + R) >> 1;
pushDown(o);
if (qr <= M) return query(lson, ql, qr);
else if (ql > M) return query(rson, ql, qr);
else {
int ret = query(lson, ql, qr) + query(rson, ql, qr);
if (rc[ls] == lc[rs]) ret--;
return ret;
}
}
void change(int u, int v, int val) {
while (top[u] != top[v]) {
if (deep[top[u]] < deep[top[v]])
swap(u, v);
update(1, 1, n, p[top[u]], p[u], val);
u = fa[top[u]];
}
if (deep[u] < deep[v]) swap(u, v);
update(1, 1, n, p[v], p[u], val);
}
int queryPath(int u, int v) {
int ret = 0;
while (top[u] != top[v]) {
if (deep[top[u]] < deep[top[v]]) swap(u, v);
ret += query(1, 1, n, p[top[u]], p[u]);
if (get(1, 1, n, p[top[u]]) == get(1, 1, n, p[fa[top[u]]]))
ret--;
u = fa[top[u]];
}
if (deep[u] < deep[v]) swap(u, v);
ret += query(1, 1, n, p[v], p[u]);
return ret;
}
int main() {
int m, k, u, v, val;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &color[i]);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
init_chain(1);
build(1, 1, n);
char str[10];
while(m--) {
scanf("%s", str);
if (str[0] == 'C') {
scanf("%d%d%d", &u, &v, &val);
change(u, v, val);
}else {
scanf("%d%d", &u, &v);
printf("%d\n", queryPath(u, v));
}
}
return 0;
}