2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 8230 Solved: 3073
[ Submit][ Status][ Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
树链剖分,用线段树维护每段上的颜色数量以及左右端点上的颜色,合并时看左右端点颜色是否相同,不同左右相加否则减一
然后区间修改
注意跳轻链的时候要特判一下链头和跳过去的节点颜色是否相同,如果相同要把答案减一。判的时候用线段树判,因为这是修改过后的线段树
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<cmath>
#define maxn 101000
#define inf 0x3f3f3f3f
#define ls p << 1
#define rs p << 1 | 1
using namespace std;
int read()
{
char ch = getchar(); int x = 0, f = 1;
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
struct tree {
int l, r, cnt, lc, rc;
bool tag;
tree() : l(0), r(0), cnt(0), lc(-1), rc(-1), tag(false) {}
}t[maxn * 10];
int pre[maxn], top;
struct edge {
int to, next;
void add(int a, int b) {
to = b;
next = pre[a];
pre[a] = top++;
}
}e[maxn * 2];
void adds(int u, int v)
{
e[top].add(u, v);
e[top].add(v, u);
}
int n, m, tot, sz;
int w[maxn], f[maxn], d[maxn], pos[maxn], deep[maxn], son[maxn], fw[maxn];
int lca(int u, int v)
{
while(d[u] != d[v])
{
if(deep[d[u]] < deep[d[v]]) swap(u, v);
u = f[d[u]];
}
if(deep[u] > deep[v]) swap(u, v);
return u;
}
void dfs1(int u, int fa)
{
son[u] = 1; f[u] = fa; deep[u] = deep[fa] + 1;
for(int i = pre[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if(v == fa) continue;
dfs1(v, u);
son[u] += son[v];
}
}
void dfs2(int u, int chain)
{
pos[u] = ++tot; d[u] = chain; int k = 0;
for(int i = pre[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if(v == f[u]) continue;
if(son[v] > son[k]) k = v;
}
if(!k) return;
dfs2(k, chain);
for(int i = pre[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if(v != f[u] && v != k)
dfs2(v, v);
}
}
void update(int p) {
t[p].cnt = t[ls].cnt + t[rs].cnt;
if(t[ls].rc == t[rs].lc) --t[p].cnt;
t[p].lc = t[ls].lc;
t[p].rc = t[rs].rc;
}
void build_tree(int p, int L, int R)
{
t[p].l = L; t[p].r = R; t[p].lc = w[L]; t[p].rc = w[R];
if(L == R) {
t[p].cnt = 1;
return;
}
int mid = L + R >> 1;
build_tree(ls, L, mid);
build_tree(rs, mid + 1, R);
update(p);
}
void init() {
n = read(); m = read(); memset(pre, -1, sizeof(pre)); tot = top = 0;
for(int i = 1;i <= n; ++i) fw[i] = read();
for(int i = 1;i < n; ++i) adds(read(), read());
dfs1(1, 0);
dfs2(1, 1);
for(int i = 1;i <= n; ++i) w[pos[i]] = fw[i];
build_tree(1, 1, n);
}
void paint(int p, int val)
{
t[p].lc = t[p].rc = val;
t[p].cnt = 1; t[p].tag = true;
}
void pushdown(int p) {
if(t[p].tag) {
paint(ls, t[p].lc);
paint(rs, t[p].lc);
t[p].tag = false;
}
}
void Seg_ch(int p, int st, int ed, int val) {
int l = t[p].l, r = t[p].r;
if(st == l && ed == r) {
paint(p, val);
return;
}
int mid = l + r >> 1;
pushdown(p);
if(st <= mid) Seg_ch(ls, st, min(mid, ed), val);
if(ed > mid) Seg_ch(rs, max(st, mid + 1), ed, val);
update(p);
}
void change(int u, int v, int val) {
while(d[u] != d[v]) {
if(deep[d[u]] < deep[d[v]]) swap(u, v);
Seg_ch(1, pos[d[u]], pos[u], val);
u = f[d[u]];
}
if(pos[u] > pos[v]) swap(u, v);
Seg_ch(1, pos[u], pos[v], val);
}
int Seg_sum(int p, int st, int ed) {
int l = t[p].l, r = t[p].r;
//cout<<p<<" "<<t[p].cnt<<" "<<st<<" "<<ed<<" "<<l<<" "<<r<<" "<<endl;
if(st == l && ed == r) return t[p].cnt;
int mid = l + r >> 1;
pushdown(p);
if(st > mid) return Seg_sum(rs, st, ed);
if(ed <= mid) return Seg_sum(ls, st, ed);
int ans = Seg_sum(ls, st, mid) + Seg_sum(rs, mid + 1, ed);
if(t[ls].rc == t[rs].lc) --ans;
return ans;
}
int Seg_col(int p, int pos) {
int l = t[p].l, r = t[p].r;
if(l == r) return t[p].lc;
pushdown(p);
int mid = l + r >> 1;
if(pos <= mid) return Seg_col(ls, pos);
else return Seg_col(rs, pos);
}
int query(int u, int v)
{
int ans = 0;
while(d[u] != d[v]) {
if(deep[d[u]] < deep[d[v]]) swap(u, v);
ans += Seg_sum(1, pos[d[u]], pos[u]);
if(Seg_col(1, pos[d[u]]) == Seg_col(1, pos[f[d[u]]])) ans--;
u = f[d[u]];
}
if(pos[u] > pos[v]) swap(u, v);
//cout<<pos[u]<<" "<<pos[v]<<endl;
ans += Seg_sum(1, pos[u], pos[v]);
return ans;
}
void solve() {
for(int i = 1;i <= m; ++i) {
char ch = getchar();
while(ch != 'Q' && ch != 'C') ch = getchar();
if(ch == 'C') {
int u = read(), v = read(), val = read();
change(u, v, val);
}
else {
int u = read(), v = read();
printf("%d\n", query(u, v));
}
}
}
int main()
{
init();
solve();
return 0;
}
/*
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
C 4 5 3
Q 5 3
C 4 6 5
Q 2 6
Q 4 3
*/