感觉好久没有写题解了。
对于这道题我们有个思路:对于线段树上的每个点,记录此点表示的区间l和r的颜色。在区间合并的时候就判断左儿子的r与右儿子的l是否相同,相同就将sum减1。
你要开始码了吗?先别急,我们用的是树链剖分,是将查询的问题拆分成几段的,也就是说不能在查询函数里直接完成。我们要提出来分别加。
先开始我想的是返回一个结构体,再比较,但其实不用这样。
画个图:
比如查询4,5。
S
t
e
p
1
Step1
Step1:算5,5。答案为1。
S
t
e
p
2
Step2
Step2:将5的top与5的top的爸爸比较,相同,答案为0。
S
t
e
p
3
Step3
Step3:计算2到4,答案为2。
其实就是判断一下top与topba的颜色是否相同。
详见代码:
#include<cstdio>
const int N = 1 * 1e5 + 2;//op数组就是id数组对应原编号
int la[N << 2], op[N], val[N], m, son[N], num, n, d[N], siz[N], tp[N], id[N], cnt, f[N], head[N], dot[N << 1], nxt[N << 1];
struct tree {
int sum, l, r;
}st[N << 2];
int Max(const int a, const int b) {
if(a > b) return a;
return b;
}
int read() {
int x = 0, f = 1;
char s = getchar();
while(s > '9' || s < '0') {
if(s == '-') f = -1;
s = getchar();
}
while(s >= '0' && s <= '9') {
x = (x << 1) + (x << 3) + (s ^ 48);
s = getchar();
}
return x * f;
}
void addEdge(const int a, const int b) {
dot[++ num] = b;
nxt[num] = head[a];
head[a] = num;
}
void init(const int u, const int ba) {
f[u] = ba;
siz[u] = 1;
d[u] = d[ba] + 1;
for(int i = head[u]; i; i = nxt[i]) {
int v = dot[i];
if(v == ba) continue;
init(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]] || son[u] == 0)
son[u] = v;
}
}
void dfs(const int u, const int t) {
tp[u] = t;
id[u] = ++ cnt;
if(! son[u]) return;
dfs(son[u], t);
for(int i = head[u]; i; i = nxt[i]) {
int v = dot[i];
if(v == f[u] || v == son[u]) continue;
dfs(v, v);
}
}
void pushUp(const int o, const int l, const int r) {
st[o].sum = st[o << 1].sum + st[o << 1 | 1].sum;
if(st[o << 1].r == st[o << 1 | 1].l) -- st[o].sum;
st[o].l = st[o << 1].l;
st[o].r = st[o << 1 | 1].r;
}
void pushDown(const int o, const int l, const int r) {
if(! la[o]) return;
int mid = l + r >> 1;
la[o << 1] = la[o << 1 | 1] = la[o];
st[o << 1].sum = st[o << 1 | 1].sum = 1;
st[o << 1].l = st[o << 1].r = st[o << 1 | 1].l = st[o << 1 | 1].r = la[o];
if(l == mid) val[op[l]] = la[o];
if(r == mid + 1) val[op[r]] = la[o];
la[o] = 0;
}
void add(const int o, const int l, const int r, const int L, const int R, const int k) {
if(l > R || r < L) return;
if(l >= L && r <= R) {
st[o].sum = 1;
st[o].l = st[o].r = la[o] = k;//这个地方要赋值l与r!!!
if(l == r) val[op[l]] = k;
return;
}
pushDown(o, l, r);
int mid = l + r >> 1;
add(o << 1, l, mid, L, R, k);
add(o << 1 | 1, mid + 1, r, L, R, k);
pushUp(o, l, r);
}
int ask(const int o, const int l, const int r, const int L, const int R) {
if(l > R || r < L) return 0;
if(l >= L && r <= R) return st[o].sum;
pushDown(o, l, r);
int mid = l + r >> 1;
int ans1 = ask(o << 1, l, mid, L, R), ans2 = ask(o << 1 | 1, mid + 1, r, L, R);
if(st[o << 1].r && st[o << 1].r == st[o << 1 | 1].l && ans1 && ans2) -- ans1;//这里注意判0
return ans1 + ans2;
}
int query(const int o, const int l, const int r, const int goal) {
if(l > goal || r < goal) return 0;
if(l == r) return val[op[l]];
pushDown(o, l, r);
int mid = l + r >> 1;
return query(o << 1, l, mid, goal) + query(o << 1 | 1, mid + 1, r, goal);
}
int main() {
char ch[5];
int a, b, c, res;
n = read(), m = read();
for(int i = 1; i <= n; ++ i) val[i] = read();
for(int i = 1; i < n; ++ i) {
a = read(), b = read();
addEdge(a, b);
addEdge(b, a);
}
init(1, 0);
dfs(1, 1);
for(int i = 1; i <= n; ++ i) add(1, 1, n, id[i], id[i], val[i]);
for(int i = 1; i <= n; ++ i) op[id[i]] = i;
while(m --) {
scanf("%s", ch);
a = read(), b = read();
if(ch[0] == 'Q') {
res = 0;
while(tp[a] != tp[b] && tp[a] && tp[b]) {
if(d[tp[a]] > d[tp[b]]) {
res += ask(1, 1, n, id[tp[a]], id[a]);
if(query(1, 1, n, id[tp[a]]) == query(1, 1, n, id[f[tp[a]]])) -- res;
a = f[tp[a]];
}
else {
res += ask(1, 1, n, id[tp[b]], id[b]);
if(query(1, 1, n, id[tp[b]]) == query(1, 1, n, id[f[tp[b]]])) -- res;
b = f[tp[b]];
}
}
if(id[a] < id[b]) res += ask(1, 1, n, id[a], id[b]);
else res += ask(1, 1, n, id[b], id[a]);
printf("%d\n", res);
}
else {
c = read();
while(tp[a] != tp[b] && tp[a] && tp[b]) {
if(d[tp[a]] > d[tp[b]]) {
add(1, 1, n, id[tp[a]], id[a], c);
a = f[tp[a]];
}
else {
add(1, 1, n, id[tp[b]], id[b], c);
b = f[tp[b]];
}
}
if(id[a] < id[b]) add(1, 1, n, id[a], id[b], c);
else add(1, 1, n, id[b], id[a], c);
}
}
return 0;
}
//附赠数据(虽然并没有什么bi用)
/*
6 100
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
C 1 4 3
C 6 6 3
Q 6 3
*/
最后,据热心市民 提供,这题用vector会T。