P2486 [SDOI2011]染色
模型总结
树链剖分+线段树
关键点
- 懒标记不要忘记下传
#include<iostream>
#include<cstdio>
#define lc (c<<1)
#define rc (c<<1|1)
using namespace std;
const int inf=1e9;
int read(){
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return f*x;
}
const int maxn=100005;
int n,m;
int w[maxn];
int h[maxn],to[maxn<<1],nxt[maxn<<1],tot;
void ade(int x,int y){
to[++tot]=y; nxt[tot]=h[x]; h[x]=tot;
}
int fa[maxn],dep[maxn],siz[maxn],
son[maxn],newp[maxn],oldp[maxn],top[maxn];
void dfs1(int f,int u){
dep[u]=dep[f]+1;
siz[u]=1;
fa[u]=f;
for(int i=h[u];i;i=nxt[i]){
int v=to[i];
if(v==f) continue;
dfs1(u,v);
if(siz[son[u]]<siz[v]) son[u]=v;
siz[u]+=siz[v];
}
}
int dfn;
void dfs2(int f,int u){
newp[u]=++dfn;
oldp[dfn]=u;
if(!son[u]) return;
top[son[u]]=top[u];
dfs2(u,son[u]);
for(int i=h[u];i;i=nxt[i]){
int v=to[i];
if(v==f||v==son[u]) continue;
top[v]=v;
dfs2(u,v);
}
}
struct node{
int l,r,coll,colr,sum,tag;
node(){coll=colr=sum=tag=0;}
}p[maxn*4];
void upd(int c){
p[c].sum=p[lc].sum+p[rc].sum;
if(p[lc].colr==p[rc].coll) p[c].sum--;
p[c].coll=p[lc].coll;
p[c].colr=p[rc].colr;
}
void build(int c,int l,int r){
p[c].l=l; p[c].r=r;
if(l==r){
p[c].coll=p[c].colr=w[oldp[l]];
p[c].sum=1;
return;
}
int mid=(p[c].r+p[c].l)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
upd(c);
}
void chg(int c,int val){
p[c].coll=p[c].colr=val;
p[c].sum=1;
p[c].tag=val;
}
void pushdown(int c){
if(p[c].tag){
chg(lc,p[c].tag);
chg(rc,p[c].tag);
p[c].tag=0;
}
}
void chg(int c,int l,int r,int val){
if(p[c].l>=l&&p[c].r<=r){
chg(c,val);
return;
}
pushdown(c);
if(p[lc].r>=l) chg(lc,l,r,val);
if(p[rc].l<=r) chg(rc,l,r,val);
upd(c);
}
int qcol(int c,int pos){
if(p[c].l==p[c].r){
return p[c].coll;
}
pushdown(c);
if(p[lc].r>=pos) return qcol(lc,pos);
return qcol(rc,pos);
}
int qsum(int c,int l,int r){
if(p[c].l>=l&&p[c].r<=r){
return p[c].sum;
}
int ans=0;
pushdown(c);
if(p[lc].r>=l) ans+=qsum(lc,l,r);
if(p[rc].l<=r) ans+=qsum(rc,l,r);
if(p[lc].r>=l&&p[rc].l<=r&&p[lc].colr==p[rc].coll) ans--;
return ans;
}
int qsumonr(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=qsum(1,newp[top[x]],newp[x]);
if(qcol(1,newp[top[x]])==qcol(1,newp[fa[top[x]]])) ans--;
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
ans+=qsum(1,newp[y],newp[x]);
return ans;
}
void chgonr(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
chg(1,newp[top[x]],newp[x],val);
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
chg(1,newp[y],newp[x],val);
}
char ch[10];
int main(){
// freopen("c.in","r",stdin);
n=read(),m=read();
for(int i=1;i<=n;i++){
w[i]=read();
}
for(int i=1;i<n;i++){
int x=read(),y=read();
ade(x,y); ade(y,x);
}
top[1]=1;
dfs1(0,1);
dfs2(0,1);
build(1,1,n);
for(int i=1;i<=m;i++){
scanf("%s",ch);
int a=read(),b=read();
if(ch[0]=='C') chgonr(a,b,read());
else printf("%d\n",qsumonr(a,b));
}
return 0;
}