BZOJ2243: [SDOI2011]染色
题目描述
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
输入
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
输出
对于每个询问操作,输出一行答案。
Solution
树链剖分+线段树维护
重点在于线段树的合并
记录每个区间的左端点和右端点的颜色
当合并线段树的时候,如果左区间的右端点和右区间的左端点相同时,当前区间权值减一
树剖查询的时候也有地方要注意,代码里有一个非常巧妙的操作
然后调了好久…
调代码的问题在注释里…
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define lson o<<1,l,Mid
#define rson o<<1|1,Mid+1,r
#define MID Mid=(l+r)>>1
#define maxn 100010
int dep[maxn],pos[maxn],top[maxn],dfn[maxn],siz[maxn],fa[maxn],son[maxn];
int head[maxn],ecnt=0,w[maxn];
int tr[maxn<<2],laz[maxn<<2],L[maxn<<2],R[maxn<<2];
int n;
void Print2(int o,int l,int r){
if(l==r){
printf("%d ",laz[o]);
return ;
}
int MID;
Print2(lson);
Print2(rson);
return ;
}
inline int read(){
int ret=0,ff=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') ff=-ff;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
ret=ret*10+ch-'0';
ch=getchar();
}
return ret*ff;
}
struct Edge{
int u,v,next;
}E[maxn<<1];
void addedge(int u,int v){
E[++ecnt].u=u;
E[ecnt].v=v;
E[ecnt].next=head[u];
head[u]=ecnt;
}
void dfs1(int x,int fx){
dep[x]=dep[fx]+1;
siz[x]=1;
fa[x]=fx;
int mx=0;
for(int i=head[x];i;i=E[i].next){
int v=E[i].v;
if(v==fx) continue;
dfs1(v,x);
siz[x]+=siz[v];
if(mx<siz[v]) son[x]=v,mx=siz[v];
}
}
int ord=0;
void dfs2(int x,int tp){
dfn[x]=++ord;
pos[ord]=x;
top[x]=tp;
if(!son[x]) return ;
dfs2(son[x],tp);
for(int i=head[x];i;i=E[i].next){
int v=E[i].v;
if(v==fa[x]||v==son[x]) continue;
dfs2(v,v);
}
return ;
}
void pushup(int o){
L[o]=L[o<<1],R[o]=R[o<<1|1];
tr[o]=tr[o<<1]+tr[o<<1|1]-(L[o<<1|1]==R[o<<1]);
}
void pushdown(int o){
if(~laz[o]){
laz[o<<1]=laz[o<<1|1]=L[o<<1]=L[o<<1|1]=R[o<<1]=R[o<<1|1]=laz[o];
tr[o<<1]=tr[o<<1|1]=1;
laz[o]=-1;
}
}
void build(int o,int l,int r){
if(l==r){
tr[o]=1;
laz[o]=L[o]=R[o]=w[pos[l]];
//还有这里我写成了laz[o]=L[l]=R[l]=w[pos[l]];
return ;
}
int MID;
build(lson);
build(rson);
pushup(o);
}
void update(int o,int l,int r,int ql,int qr,int val){
if(ql<=l&&r<=qr){
tr[o]=1;
L[o]=R[o]=laz[o]=val;
//没错就是这里,我写成了L[o<<1|1]=R[o<<1]=laz[o]=val;
return ;
}
pushdown(o);
int MID;
if(ql<=Mid) update(lson,ql,qr,val);
if(qr>Mid) update(rson,ql,qr,val);
pushup(o);
return ;
}
int query(int o,int l,int r,int ql,int qr){
if(ql==l&&qr==r) return tr[o];
int MID;
pushdown(o);
if(Mid>=qr) return query(lson,ql,qr);
else if(Mid<ql) return query(rson,ql,qr);
else{
int tmp=1;
if(L[o<<1|1]^R[o<<1]) tmp=0;
return query(lson,ql,Mid)+query(rson,Mid+1,qr)-tmp;
}
}
void Update(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],val);
return ;
}
int qc(int o,int l,int r,int x){
if(l==r) return L[o];
int MID;
pushdown(o);
if(x<=Mid) return qc(lson,x);
else return qc(rson,x);
}
int Query(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=query(1,1,n,dfn[top[x]],dfn[x]);
int cc=qc(1,1,n,dfn[top[x]]),tt=qc(1,1,n,dfn[fa[top[x]]]);
ans-=(cc==tt);
//巧妙的地方就在这里
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(1,1,n,dfn[x],dfn[y]);
return ans;
}
int main(){
n=read();
int m=read();
for(int i=1;i<=n;++i) w[i]=read();
for(int i=1;i<n;++i){
int a=read(),b=read();
addedge(a,b);
addedge(b,a);
}
dfs1(1,0);
dfs2(1,0);
memset(laz,-1,sizeof(laz));
memset(L,-1,sizeof(L));
memset(R,-1,sizeof(R));
build(1,1,n);
while(m--){
char op[5];
scanf("%s",op);
if(op[0]=='C'){
int a=read(),b=read(),c=read();
Update(a,b,c);
}
else{
int a=read(),b=read();
printf("%d\n",Query(a,b));
}
}
return 0;
}