为了练手速我花了半个小时打完了这道题。。然后debug的时候就。。23333
首先如果是一个序列显然可以用线段树区间修改,维护段中的颜色数量,左右端点的颜色来做吧。
树上也一样,我们可以把树上的区间转化为dfs序列中若干个连续区间,然后用树链剖分使区间的个数<logN,注意一下端点的问题就好了(说白了就是一道树链剖分裸题)。
AC代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 100005
using namespace std;
int n,m,dfsclk,bin[25],a[N],b[N],d[N],fa[N][17],sz[N],son[N],anc[N],pos[N];
int tot,fst[N],pnt[N<<1],nxt[N<<1],cvr[N<<2];
struct node{ int l,r,sum; }val[N<<2];
int read(){
int x=0; char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x;
}
void add(int x,int y){
pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot;
}
void dfs(int x){
sz[x]=1; int p,i;
for (i=1; bin[i]<=d[x]; i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x][0]){
fa[y][0]=x; d[y]=d[x]+1;
dfs(y); sz[x]+=sz[y];
if (sz[y]>sz[son[x]]) son[x]=y;
}
}
}
void divide(int x,int tp){
pos[x]=++dfsclk; anc[x]=tp; int p;
if (son[x]) divide(son[x],tp);
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x][0] && y!=son[x]) divide(y,y);
}
}
int lca(int x,int y){
if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i;
for (i=0; bin[i]<=tmp; i++)
if (tmp&bin[i]) x=fa[x][i];
for (i=16; i>=0; i--)
if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; }
return (x==y)?x:fa[x][0];
}
void maintain(int k){
int l=k<<1,r=l|1;
val[k].sum=val[l].sum+val[r].sum; if (val[l].r==val[r].l) val[k].sum--;
val[k].l=val[l].l; val[k].r=val[r].r;
}
void chg(int k,int v){
val[k].sum=1; val[k].l=val[k].r=cvr[k]=v;
}
void pushdown(int k){
if (cvr[k]!=-1){
chg(k<<1,cvr[k]); chg(k<<1|1,cvr[k]); cvr[k]=-1;
}
}
void build(int k,int l,int r){
cvr[k]=-1;
if (l==r){
val[k].l=val[k].r=a[l]; val[k].sum=1;
return;
}
int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r);
maintain(k);
}
void ins(int k,int l,int r,int x,int y,int z){
if (l==x && r==y){ chg(k,z); return; }
int mid=(l+r)>>1;
pushdown(k);
if (y<=mid) ins(k<<1,l,mid,x,y,z); else
if (x>mid) ins(k<<1|1,mid+1,r,x,y,z); else{
ins(k<<1,l,mid,x,mid,z); ins(k<<1|1,mid+1,r,mid+1,y,z);
}
maintain(k);
}
node qry(int k,int l,int r,int x,int y){
if (cvr[k]!=-1){
node t1; t1.sum=1; t1.l=t1.r=cvr[k]; return t1;
}
if (l==x && r==y) return val[k]; int mid=(l+r)>>1;
if (y<=mid) return qry(k<<1,l,mid,x,y); else
if (x>mid) return qry(k<<1|1,mid+1,r,x,y); else{
node t1=qry(k<<1,l,mid,x,mid),t2=qry(k<<1|1,mid+1,r,mid+1,y);
t1.sum+=t2.sum; if (t1.r==t2.l) t1.sum--; t1.r=t2.r;
return t1;
}
}
void mdy(int x,int y,int z){
for (; anc[x]!=anc[y]; x=fa[anc[x]][0])
ins(1,1,n,pos[anc[x]],pos[x],z);
ins(1,1,n,pos[y],pos[x],z);
}
node solve(int x,int y){
if (anc[x]==anc[y]) return qry(1,1,n,pos[y],pos[x]);
node t1=qry(1,1,n,pos[anc[x]],pos[x]),t2;
x=fa[anc[x]][0];
for (; anc[x]!=anc[y]; x=fa[anc[x]][0]){
t2=qry(1,1,n,pos[anc[x]],pos[x]);
t1.sum+=t2.sum; if (t1.l==t2.r) t1.sum--; t1.l=t2.l;
}
t2=qry(1,1,n,pos[y],pos[x]);
t1.sum+=t2.sum; if (t1.l==t2.r) t1.sum--; t1.l=t2.l;
return t1;
}
int main(){
n=read(); m=read(); int i;
bin[0]=1; for (i=1; i<=17; i++) bin[i]=bin[i-1]<<1;
for (i=1; i<=n; i++) b[i]=read();
for (i=1; i<n; i++){
int x=read(),y=read();
add(x,y); add(y,x);
}
dfs(1); divide(1,1);
for (i=1; i<=n; i++) a[pos[i]]=b[i]; build(1,1,n);
char ch;
while (m--){
ch=getchar(); while (ch<'A' || ch>'Z') ch=getchar();
if (ch=='C'){
int x=read(),y=read(),z=read(),tmp=lca(x,y);
mdy(x,tmp,z); mdy(y,tmp,z);
} else{
int x=read(),y=read(),tmp=lca(x,y);
node t1=solve(x,tmp),t2=solve(y,tmp);
printf("%d\n",t1.sum+t2.sum-1);
}
}
return 0;
}
by lych
2016.3.8