Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
Solution
树上路径问题,首先树链剖分,然后线段树中元素lc和rc表示区间左右端点的颜色,num表示区间颜色段数量,flag表示是否被覆盖(如果被覆盖flag记录被覆盖的颜色),每次C a b c操作就是区间更新[a,lca(a,b)]和[b,lca(a,b)],Q a b就是区间查询[a,lca(a,b)]和[b,lca(a,b)]的颜色段数量然后减一( lca(a,b)被算了两遍),注意每次区间合并时需要判断左区间的rc和右区间的lc是否相同,如果相同那么总区间的颜色段数量等于左右区间颜色段数量之和减一
Code
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define maxn 111111
struct Edge
{
int to,next;
}E[2*maxn];
struct Tree
{
int left,right,lc,rc,num,flag;
}T[4*maxn];
//lc,rc记录区间端点即left和right的颜色
//num记录区间[left,right]中颜色段数量
//flag记录区间是否被某种颜色全覆盖
int n,q,c[maxn],head[maxn],cnt,idx,size[maxn],fa[maxn],son[maxn],dep[maxn],top[maxn],id[maxn];
void init()
{
cnt=idx=0;
memset(head,-1,sizeof(head));
dep[1]=fa[1]=size[0]=0;
memset(son,0,sizeof(son));
}
void add(int u,int v)
{
E[cnt].to=v;
E[cnt].next=head[u];
head[u]=cnt++;
}
void dfs1(int u)
{
size[u]=1;
for(int i=head[u];~i;i=E[i].next)
{
int v=E[i].to;
if(v!=fa[u])
{
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
size[u]+=size[v];
if(size[son[u]]<size[v]) son[u]=v;
}
}
}
void dfs2(int u,int topu)
{
top[u]=topu;
id[u]=++idx;
if(son[u]) dfs2(son[u],top[u]);
for(int i=head[u];~i;i=E[i].next)
{
int v=E[i].to;
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
}
void push_up(int t)
{
T[t].lc=T[2*t].lc;
T[t].rc=T[2*t+1].rc;
T[t].num=T[2*t].num+T[2*t+1].num;
if(T[2*t].rc==T[2*t+1].lc) T[t].num--;//左区间右端点颜色等于右区间左端点颜色时总颜色段数量减一
}
void push_down(int t)
{
if(T[t].left==T[t].right||T[t].flag==-1)
{
T[t].flag=-1;
return ;
}
T[2*t].lc=T[2*t].rc=T[2*t+1].lc=T[2*t+1].rc=T[t].flag;
T[2*t].flag=T[2*t+1].flag=T[t].flag;
T[2*t].num=T[2*t+1].num=1;
T[t].flag=-1;
}
void build(int l,int r,int t)//建树
{
T[t].left=l;
T[t].right=r;
T[t].num=1;
T[t].flag=-1;
if(l==r) return ;
int mid=(l+r)>>1;
build(l,mid,2*t);
build(mid+1,r,2*t+1);
}
void update(int l,int r,int z,int t)//区间更新[l,r]的颜色为t
{
push_down(t);
if(T[t].left==l&&T[t].right==r)
{
T[t].lc=T[t].rc=T[t].flag=z;
T[t].num=1;
return ;
}
if(r<=T[2*t].right) update(l,r,z,2*t);
else if(l>=T[2*t+1].left) update(l,r,z,2*t+1);
else
{
update(l,T[2*t].right,z,2*t);
update(T[2*t+1].left,r,z,2*t+1);
}
push_up(t);
}
int modify(int l,int r,int t)//区间查询[l,r]上的颜色段数量
{
push_down(t);
if(T[t].left==l&&T[t].right==r) return T[t].num;
if(r<=T[2*t].right) return modify(l,r,2*t);
else if(l>=T[2*t+1].left) return modify(l,r,2*t+1);
else
{
int ans=T[2*t].rc==T[2*t+1].lc;//左区间右端点颜色等于右区间左端点颜色时总颜色段数量减一
return modify(l,T[2*t].right,2*t)+modify(T[2*t+1].left,r,2*t+1)-ans;
}
}
int get_color(int x,int t)//单点查询树上某节点颜色
{
if(T[t].left==x&&T[t].right==x) return T[t].lc;
push_down(t);
if(x<=T[2*t].right) return get_color(x,2*t);
return get_color(x,2*t+1);
}
int lca(int u,int v)//求u节点和v节点的lca
{
int top1=top[u],top2=top[v];
while(top1!=top2)
{
if(dep[top1]<dep[top2])
{
swap(top1,top2);
swap(u,v);
}
u=fa[top1];
top1=top[u];
}
return dep[u]<dep[v]?u:v;
}
void change(int f,int u,int z)//修改树上f节点到u节点之间的颜色为z
{
int top1=top[u],top2=top[f];
while(top1!=top2)
{
update(id[top1],id[u],z,1);
u=fa[top1];
top1=top[u];
}
update(id[f],id[u],z,1);
}
int query(int f,int u)//查询树上f节点到u节点之间的颜色段数量
{
int top1=top[u],top2=top[f],ans=0;
while(top1!=top2)
{
ans+=modify(id[top1],id[u],1);
if(get_color(id[top1],1)==get_color(id[fa[top1]],1)) ans--;//左区间右端点颜色等于右区间左端点颜色时总颜色段数量减一
u=fa[top1];
top1=top[u];
}
ans+=modify(id[f],id[u],1);
return ans;
}
int main()
{
while(~scanf("%d%d",&n,&q))
{
init();//初始化
int u,v,color;char op[11];
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs1(1);
dfs2(1,1);
build(1,n,1);//建树
for(int i=1;i<=n;i++) update(id[i],id[i],c[i],1);//单点更新就是左右端点相同的区间更新
while(q--)
{
scanf("%s",op);
if(op[0]=='C')
{
scanf("%d%d%d",&u,&v,&color);
int x=lca(u,v);
change(x,u,color);
change(x,v,color);
}
else
{
scanf("%d%d",&u,&v);
int x=lca(u,v);
printf("%d\n",query(x,u)+query(x,v)-1);
}
}
}
return 0;
}