裸题嘛。。
先考虑一条线段上如何查询颜色段数,只要对每个线段树节点多维护一个左颜色和右颜色,然后合并的时候sum[x]=sum[lc]+sum[rc]-(左儿子的右颜色==右儿子的左颜色)。。实在太久没写树剖结果码+调试花了两节多晚自习,,各种傻逼错误,什么反向边忘加,标记忘记下传。。。还有就是更新答案的时候,关键的一点是要保证当前的两点(也就是a,b)是没有被更新到的,否则很难搞。。
表示LCT要更好写。。不过在BZOJ上我的树链剖分6000+MS,LCT要13000+MS。。
树链剖分:
#include<iostream>
#include<cstdio>
#include<memory.h>
#define maxn 100005
using namespace std;
struct edge{
int e,next;
} ed[maxn*2];
int n,m,i,cnt=0,ne=0,nd=0,s,e,ll,rr,cc,a[maxn],root[maxn],xu[maxn],belong[maxn],top[maxn],col[maxn],size[maxn],chain[maxn],last[maxn],pre[maxn],d[maxn];
int c[maxn*2][2],sum[maxn*2],l[maxn*2],r[maxn*2],lcol[maxn*2],rcol[maxn*2],tag[maxn*2];
char ch;
void add(int s,int e){ed[++ne].e=e;ed[ne].next=a[s];a[s]=ne;}
void dfs(int x)
{
int j,i=0,maxs=0,to;
size[x]=1;belong[x]=0;
for (j=a[x];j;j=ed[j].next)
if (ed[j].e!=pre[x])
{
to=ed[j].e;d[to]=d[x]+1;pre[to]=x;
dfs(to);size[x]+=size[to];
if (size[to]>maxs) maxs=size[to],i=to;
}
for (j=a[x];j;j=ed[j].next)
if (ed[j].e!=pre[x])
{
to=ed[j].e;
if (to==i) xu[x]=xu[to]+1,belong[x]=belong[to],last[x]=to,top[belong[to]]=x;
else top[belong[to]]=to;
}
if (!belong[x]) belong[x]=++cnt,xu[x]=1,last[x]=-1,top[cnt]=x;
}
void update(int x)
{
lcol[x]=lcol[c[x][0]];rcol[x]=rcol[c[x][1]];
sum[x]=sum[c[x][0]]+sum[c[x][1]]-(rcol[c[x][0]]==lcol[c[x][1]]);
}
void build(int &x,int ll,int rr)
{
x=++nd;l[x]=ll;r[x]=rr;tag[x]=-1;
if (ll==rr)
{
lcol[x]=rcol[x]=chain[ll];
sum[x]=1;
return;
}
build(c[x][0],ll,(ll+rr)/2);build(c[x][1],(ll+rr)/2+1,rr);
update(x);
}
void prepare()
{
int i,j;
pre[1]=0;d[1]=1;dfs(1);
for (i=1;i<=cnt;i++)
{
for (j=top[i];j!=-1;j=last[j]) chain[xu[j]]=col[j];
build(root[i],1,xu[top[i]]);
}
// for (i=1;i<=n;i++) printf("%d %d %d %d*\n",i,belong[i],root[belong[i]],d[i]);
}
void mark(int x,int cc)
{
if (!x) return;
tag[x]=cc;sum[x]=1;
lcol[x]=rcol[x]=cc;
}
void down(int x)
{
if (tag[x]!=-1) mark(c[x][0],tag[x]),mark(c[x][1],tag[x]),tag[x]=-1;
}
void ins(int x,int ll,int rr,int cc)
{
if (ll>r[x]||rr<l[x]) return;
if (ll<=l[x]&&rr>=r[x]) {mark(x,cc);return;}
down(x);
ins(c[x][0],ll,rr,cc);ins(c[x][1],ll,rr,cc);
update(x);
}
int query(int x,int ll,int rr)
{
if (ll>r[x]||rr<l[x]) return 0;
if (ll<=l[x]&&rr>=r[x]) return sum[x];
down(x);
int ans=query(c[x][0],ll,rr)+query(c[x][1],ll,rr),mid=(l[x]+r[x])/2;
// printf("%d %d %d %d %d###\n",ll,rr,l[x],r[x],ans);
if (ll<=mid&&rr>mid) ans-=(rcol[c[x][0]]==lcol[c[x][1]]);
return ans;
}
int getc(int x,int w)
{
int mid=(l[x]+r[x])/2;
down(x);
if (l[x]==r[x]) return lcol[x];
return w<=mid? getc(c[x][0],w):getc(c[x][1],w);
}
void change(int a,int b,int cc)
{
int ba,bb;
while (belong[a]!=belong[b])
{
ba=belong[a];bb=belong[b];
if (d[top[ba]]<d[top[bb]]) swap(a,b),swap(ba,bb);
ins(root[ba],xu[a],xu[top[ba]],cc);
a=top[ba];if (pre[a]) a=pre[a];
}
if (d[a]<d[b]) swap(a,b);
ins(root[belong[a]],xu[a],xu[b],cc);
}
int solve(int a,int b)
{
int ba,bb,ans=0,t,aaa=a,bbb=b;
bool f=false;
if (a==b) return 1;
while (belong[a]!=belong[b])
{
f=true;
ba=belong[a];bb=belong[b];
if (d[top[ba]]<d[top[bb]]) swap(a,b),swap(ba,bb),swap(aaa,bbb);
ans+=query(root[ba],xu[a],xu[top[ba]]);
a=top[ba];
if (pre[a]&&getc(root[ba],xu[a])==getc(root[belong[pre[a]]],xu[pre[a]])) ans--;
if (pre[a]) a=pre[a];//printf("%d %d %d**\n",a,b,ans);
}
if (d[a]<d[b]) swap(a,b),swap(aaa,bbb);
t=query(root[belong[a]],xu[a],xu[b]);
if (!f) return t;
// printf("%d %d***\n",t,ans);
return ans+t;
}
int main()
{
freopen("2243.in","r",stdin);
freopen("2243.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=1;i<=n;i++) scanf("%d",&col[i]);
for (i=1;i<n;i++)
{
scanf("%d%d",&s,&e);
add(s,e);add(e,s);
}
prepare();
scanf("\n");
for (i=1;i<=m;i++)
{
scanf("%c%d%d",&ch,&ll,&rr);
if (ch=='Q') printf("%d\n",solve(ll,rr));
else scanf("%d",&cc),change(ll,rr,cc);
scanf("\n");
}
fclose(stdin);
fclose(stdout);
}
LCT:
#include<iostream>
#include<cstdio>
#include<memory.h>
#define maxn 100005
using namespace std;
struct edge{
int e,next;
}ed[maxn*2];
int n,m,s,e,l,r,cc,ne=0,i,a[maxn],c[maxn][2],pre[maxn],sum[maxn],lcol[maxn],rcol[maxn],tag[maxn],col[maxn],d[maxn];
char ch;
void add(int s,int e){ed[++ne].e=e;ed[ne].next=a[s];a[s]=ne;}
void dfs(int x)
{
int j,i=0,to;
tag[x]=-1;lcol[x]=rcol[x]=col[x];
sum[x]=1;c[x][0]=c[x][1]=0;
for (j=a[x];j;j=ed[j].next)
if (ed[j].e!=pre[x])
{
to=ed[j].e;d[to]=d[x]+1;pre[to]=x;
dfs(to);
}
}
void update(int x)
{
sum[x]=sum[c[x][0]]+sum[c[x][1]]+1;
if (c[x][0]) lcol[x]=lcol[c[x][0]],sum[x]-=(rcol[c[x][0]]==col[x]); else lcol[x]=col[x];
if (c[x][1]) rcol[x]=rcol[c[x][1]],sum[x]-=(lcol[c[x][1]]==col[x]); else rcol[x]=col[x];
}
void mark(int x,int cc)
{
if (!x) return;
lcol[x]=rcol[x]=col[x]=cc;
sum[x]=1;tag[x]=cc;
}
void down(int x)
{
if (tag[x]!=-1) mark(c[x][0],tag[x]),mark(c[x][1],tag[x]),tag[x]=-1;
}
bool isroot(int x){return !pre[x]||(c[pre[x]][0]!=x&&c[pre[x]][1]!=x);}
void rot(int x,int kind)
{
int y=pre[x],z=pre[y];
down(y);down(x);
if (!isroot(y)&&z) c[z][c[z][1]==y]=x;
c[y][!kind]=c[x][kind];pre[c[x][kind]]=y;
c[x][kind]=y;pre[y]=x;
pre[x]=z;
update(y);update(x);
}
void splay(int x)
{
int y,z,kind;
while (!isroot(x))
{
y=pre[x];
if (isroot(y)) rot(x,c[y][0]==x);
else
{
int z=pre[y],kind=c[z][0]==y;
if (c[y][kind]==x) rot(x,!kind);else rot(y,kind);
rot(x,kind);
}
}
down(x);
}
void access(int x)
{
int u;
splay(x);
c[x][1]=0;update(x);
while (pre[x])
{
u=pre[x];splay(u);
c[u][1]=x;update(u);splay(x);
}
}
int lca(int x,int y)
{
access(y);
int u;
splay(x);
c[x][1]=0;
while (pre[x])
{
u=pre[x];splay(u);
if (pre[u]==0) return u;
c[u][1]=x;update(u);splay(x);
}
return x;
}
int getpre(int x)
{
splay(x);
if (!c[x][0]) return 0;
x=c[x][0];
while (c[x][1]) x=c[x][1];
return x;
}
void change(int x,int y,int cc)
{
if (d[x]<d[y]) swap(x,y);
int fa=lca(x,y),u;
u=getpre(fa);
if (!u) mark(fa,cc); else splay(u),mark(c[u][1],cc);
access(x);splay(fa);
if (!u) mark(fa,cc); else splay(u),mark(c[u][1],cc);
}
int query(int x,int y)
{
if (d[x]<d[y]) swap(x,y);
int fa=lca(x,y),u,ans=0;
u=getpre(fa);
if (!u) ans+=sum[fa]; else splay(u),ans+=sum[c[u][1]];
access(x);splay(fa);
if (!u) ans+=sum[fa]; else splay(u),ans+=sum[c[u][1]];
ans--;
return ans;
}
int main()
{
freopen("2243.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
memset(a,0,sizeof(a));
sum[0]=pre[0]=0;
for (i=1;i<=n;i++) scanf("%d",&col[i]);
for (i=1;i<n;i++)
{
scanf("%d%d",&s,&e);
add(s,e);add(e,s);
}
pre[1]=0;d[1]=1;dfs(1);
scanf("\n");
for (i=1;i<=m;i++)
{
scanf("%c%d%d",&ch,&l,&r);
if (ch=='Q') printf("%d\n",query(l,r));
else
{
scanf("%d",&cc);
change(l,r,cc);
}
scanf("\n");
}
fclose(stdin);
fclose(stdout);
}