题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=2243
[SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MB
Submit: 6131 Solved: 2243
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
思路:基于树链剖分的一道线段树的区间合并简单题。需要维护的信息有三个:节点最左边的值,节点最右边的值,节点的答案值。最左最右值在区间合并时有用,对于每一次区间衔接时,判断一下端点值是否相同即可。详见代码。
附上AC代码:
#include <bits/stdc++.h>
#define lrt rt<<1
#define rrt rt<<1|1
#define lson l, m, lrt
#define rson m+1, r, rrt
using namespace std;
const int maxn = 100005;
// 分别表示以当前节点作为根的子树的节点数目,
// 树上各个节点的初始值,当前节点的重儿子
int sizev[maxn], num[maxn], son[maxn];
// 分别表示树链上深度最小的节点,当前节点的深度,
// 原节点在剖分后的时间戳,即新的编号
int top[maxn], deep[maxn], pos[maxn];
// 分别表示当前时间戳对应的原节点编号,
// 当前节点的父节点
int level[maxn], p[maxn];
// 判断该节点是否被访问过了
bool vis[maxn];
// 分别表示节点数,询问数和时间戳计数
int n, q, cnt;
// 存储树的各个节点所连接的边
vector<int> edge[maxn];
void init(){
for (int i=1; i<=n; ++i){
sizev[i] = top[i] = son[i] = 0;
deep[i] = pos[i] = level[i] = 0;
p[i]=0, vis[i]=false;
cnt = 0;
edge[i].clear();
}
}
void add_edge(int u, int v){
edge[u].push_back(v);
edge[v].push_back(u);
}
void dfs1(int u, int root){
vis[u] = true;
sizev[u] = 1;
p[u] = root;
deep[u] = deep[root]+1;
int siz = edge[u].size();
for (int i=0; i<siz; ++i){
int v = edge[u][i];
if (v!=p[u] && !vis[v]){
dfs1(v, u);
sizev[u] += sizev[v];
if (son[u] == 0)
son[u] = v;
else if (sizev[son[u]] < sizev[v])
son[u] = v;
}
}
}
void dfs2(int u, int root){
vis[u] = true;
pos[u] = ++cnt;
level[cnt] = u;
top[u] = root;
if (son[u])
dfs2(son[u], root);
int siz = edge[u].size();
for (int i=0; i<siz; ++i){
int v = edge[u][i];
if (v!=p[u] && v!=son[u] && !vis[v])
dfs2(v, v);
}
}
int sumv[maxn<<2], setv[maxn<<2];
int lv[maxn<<2], rv[maxn<<2];
char op[5];
void push_up(int rt){
lv[rt]=lv[lrt], rv[rt]=rv[rrt];
sumv[rt] = sumv[lrt]+sumv[rrt]-(rv[lrt]==lv[rrt]);
}
void build(int l, int r, int rt){
setv[rt] = -1;
if (l == r){
sumv[rt] = 1;
lv[rt] = rv[rt] = num[level[l]];
return ;
}
int m = (l+r)>>1;
build(lson);
build(rson);
push_up(rt);
}
void push_down(int l, int r, int rt){
if (setv[rt] != -1){
setv[lrt] = setv[rrt] = setv[rt];
sumv[lrt] = sumv[rrt] = 1;
lv[lrt] = lv[rrt] = setv[rt];
rv[lrt] = rv[rrt] = setv[rt];
setv[rt] = -1;
}
}
void update(int cl, int cr, int val, int l, int r, int rt){
if (cl<=l && r<=cr){
setv[rt] = lv[rt] = rv[rt] = val;
sumv[rt] = 1;
return ;
}
push_down(l, r, rt);
int m = (l+r)>>1;
if (cl <= m)
update(cl, cr, val, lson);
if (cr > m)
update(cl, cr, val, rson);
push_up(rt);
}
int queries(int ql, int qr, int l, int r, int rt){
if (ql<=l && r<=qr)
return sumv[rt];
push_down(l, r, rt);
int m = (l+r)>>1;
int sumr = 0;
if (ql <= m)
sumr += queries(ql, qr, lson);
if (qr > m)
sumr += queries(ql, qr, rson);
if (ql<=m && qr>m && rv[lrt]==lv[rrt])
--sumr;
return sumr;
}
void change(int x, int y, int val){
while (top[x] != top[y]){
if (deep[top[x]] < deep[top[y]])
swap(x, y);
update(pos[top[x]], pos[x], val, 1, n, 1);
x = p[top[x]];
}
if (deep[x] > deep[y])
swap(x, y);
update(pos[x], pos[y], val, 1, n, 1);
}
int get_find(int pos, int l, int r, int rt){
if (l == r)
return lv[rt];
push_down(l, r, rt);
int m = (l+r)>>1;
if (pos <= m)
return get_find(pos, lson);
return get_find(pos, rson);
}
int seek(int x, int y){
int ans = 0;
while (top[x] != top[y]){
if (deep[top[x]] < deep[top[y]])
swap(x, y);
ans += queries(pos[top[x]], pos[x], 1, n, 1);
int t1 = get_find(pos[top[x]], 1, n, 1);
int t2 = get_find(pos[p[top[x]]], 1, n, 1);
if (t1 == t2)
--ans;
x = p[top[x]];
}
if (deep[x] > deep[y])
swap(x, y);
return ans+queries(pos[x], pos[y], 1, n, 1);
}
int main(){
while (~scanf("%d%d", &n, &q)){
init();
for (int i=1; i<=n; ++i)
scanf("%d", num+i);
int a, b, c;
for (int i=1; i<n; ++i){
scanf("%d%d", &a, &b);
add_edge(a, b);
}
dfs1(1, 0);
memset(vis, false, sizeof(bool)*(n+1));
dfs2(1, 1);
build(1, n, 1);
while (q--){
scanf("%s%d%d", op, &a, &b);
if (op[0] == 'Q')
printf("%d\n", seek(a, b));
else{
scanf("%d", &c);
change(a, b, c);
}
}
}
return 0;
}