[SDOI2011]染色
describes
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
HINT
Time Limit: 20 Sec Memory Limit: 512 MB
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
题解:
题目初看很容易理解,就知道是线段树经典题目和树链剖分的结合,用线段树维护区间的sum,left color,right color,delta即可。update,和pushdown和线段树里的经典题一样不用说了。
但是很容易发现的问题就是由于线段树维护的是链剖之后的树,也就是说查询的区间可能不会相连。
这时可以发现可以通过原树中的top[x],和fa[top[x]]将查询的链联系起来。因为如果换重链,相邻的两个节点一定是top[x]和fa[top[x]],接下来就想在线段树上维护颜色一样维护就行了。
即:
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
const int max_n = 1e5+5;
struct node{
int l,r;
int lc,rc,sum,delta;
}tree[max_n*4];
int fa[max_n],deep[max_n],nxt[max_n*2],point[max_n*2],v[max_n*2],val[max_n],top[max_n],size[max_n],rank_n[max_n];
char c;
int n,m,tot,x,y,k,last,cha;
inline void clear()
{
memset(point,-1,sizeof(point));
memset(nxt,-1,sizeof(nxt));
tot=0;
}
inline void addedge(int x,int y)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
++tot; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
inline void update(int now)
{
tree[now].sum=tree[now<<1].sum+tree[(now<<1)+1].sum;
tree[now].lc=tree[now<<1].lc;
tree[now].rc=tree[(now<<1)+1].rc;
if(tree[now<<1].rc==tree[(now<<1)+1].lc)
tree[now].sum-=1;
}
inline void pushdown(int now)
{
if(tree[now].delta!=0)
{
tree[now<<1].sum=1;
tree[now<<1].delta=tree[now].delta;
tree[now<<1].lc=tree[now<<1].rc=tree[now].delta;
tree[(now<<1)+1].sum=1;
tree[(now<<1)+1].delta=tree[now].delta;
tree[(now<<1)+1].lc=tree[(now<<1)+1].rc=tree[now].delta;
tree[now].delta=0;
}
}
inline void build(int now,int l,int r)
{
tree[now].l=l;
tree[now].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(now<<1,l,mid);
build((now<<1)+1,mid+1,r);
}
inline void change(int now,int l,int r,int val)
{
int l1=tree[now].l;
int r1=tree[now].r;
if(l1>=l && r1<=r)
{
tree[now].sum=1;
tree[now].delta=val;
tree[now].lc=tree[now].rc=val;
return;
}
pushdown(now);
int mid=(l1+r1)>>1;
if(l<=mid)
change(now<<1,l,r,val);
if(r>mid)
change((now<<1)+1,l,r,val);
update(now);
}
inline void schange(int x,int y,int val)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
change(1,rank_n[top[x]],rank_n[x],val);
x=fa[top[x]];
}
if(rank_n[x]>rank_n[y]) swap(x,y);
change(1,rank_n[x],rank_n[y],val);
}
inline void dfs1(int now,int f)
{
size[now]=1;
deep[now]=deep[f]+1;
fa[now]=f;
for(int i=point[now]; i!=-1; i=nxt[i])
if(v[i]!=f)
{
dfs1(v[i],now);
size[now]+=size[v[i]];
}
}
inline void dfs2(int now,int tip)
{
top[now]=tip;
rank_n[now]=++tot;
change(1,tot,tot,val[now]);
if(now!=1 && nxt[point[now]]==-1)
return;
int mson=0;
for(int i=point[now]; i!=-1; i=nxt[i])
if(size[v[i]]<size[now] && size[mson]<size[v[i]]) mson=v[i];
dfs2(mson,tip);
for(int i=point[now]; i!=-1; i=nxt[i])
if(size[v[i]]<size[now] && v[i]!=mson)
dfs2(v[i],v[i]);
}
inline int query(int now,int l,int r)
{
int l1=tree[now].l;int r1=tree[now].r;
if(l1>=l && r1<=r) return tree[now].sum;
pushdown(now);
int ans=0,mid=(l1+r1)>>1;
if(l<=mid)
ans+=query(now<<1,l,r);
if(r>mid)
ans+=query((now<<1)+1,l,r);
if(l<=mid && r>mid)
if(tree[now<<1].rc==tree[(now<<1)+1].lc) ans--;//判断相邻两个色段交汇处有无同样的颜色
return ans;
}
inline int query_c(int now,int tar)//查询颜色
{
int l=tree[now].l;
int r=tree[now].r;
if(l==r && l==tar) return tree[now].lc;
pushdown(now);
int mid=(l+r)/2;
if(tar<=mid)
return query_c(now<<1,tar);
else
return query_c((now<<1)+1,tar);
}
inline int squery(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
ans+=query(1,rank_n[top[x]],rank_n[x]);
if(query_c(1,rank_n[fa[top[x]]])==query_c(1,rank_n[top[x]])) ans-=1;//和线段树不同点:换重链是判断相邻重链是否颜色一致
x=fa[top[x]];
}
if(rank_n[x]>rank_n[y]) swap(x,y);
ans+=query(1,rank_n[x],rank_n[y]);
return ans;
}
int main()
{
clear();
scanf("%d%d",&n,&m);
for(int i=1; i<=n; ++i)
scanf("%d",&val[i]);
for(int i=1; i<=n-1; ++i)
{
scanf("%d%d",&x,&y);
addedge(x,y);
}
tot=0;
build(1,1,n);
dfs1(1,0);
dfs2(1,1);
for(int i=1; i<=m; ++i)
{
cin>>c;
if(c=='Q')
{
scanf("%d%d",&x,&y);
printf("%d\n",squery(x,y));
}
else
{
scanf("%d%d%d",&x,&y,&k);
schange(x,y,k);
}
}
return 0;
}