学习了一下树链剖分。
相关文章:http://blog.sina.com.cn/s/blog_7a1746820100wp67.html
这题差不多算是模板题了吧。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 100005
struct SegTree
{
int l,r;
int lc,rc;
int num;
} tree[maxn<<2];
int size[maxn],fa[maxn],dep[maxn],son[maxn];
int top[maxn],tid[maxn],label;
int head[maxn],to[maxn<<1],next[maxn<<1],edge,e[maxn][3];
int root;
int color[maxn],p[maxn];
int LC,RC;
void init()
{
root=1;
memset(head,-1,sizeof(head));
fa[root]=dep[root]=label=edge=0;
}
inline void add(int u,int v)
{
to[edge]=v,next[edge]=head[u],head[u]=edge++;
}
inline void pushup(int rt)
{
int ls=rt<<1,rs=ls+1;
tree[rt].lc=tree[ls].lc;
tree[rt].rc=tree[rs].rc;
tree[rt].num=tree[ls].num+tree[rs].num;
if(tree[ls].rc==tree[rs].lc)
tree[rt].num--;
}
inline void pushdown(int rt)
{
if(tree[rt].num==1)
{
int ls=rt<<1,rs=ls+1;
tree[ls].num=tree[rs].num=1;
tree[ls].lc=tree[rs].lc=tree[ls].rc=tree[rs].rc=tree[rt].lc;
}
}
void build(int rt,int l,int r)
{
tree[rt].l=l,tree[rt].r=r;
if(l==r)
{
tree[rt].lc=tree[rt].rc=color[p[l]];
tree[rt].num=1;
return ;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void Find(int x)
{
size[x]=1,son[x]=0;
for(int i=head[x],v; ~i; i=next[i])
{
v=to[i];
if(v==fa[x]) continue;
fa[v]=x;
dep[v]=dep[x]+1;
Find(v);
size[x]+=size[v];
if(size[v]>size[son[x]])
son[x]=v;
}
}
void Connect(int x,int ancestor)
{
tid[x]=++label,top[x]=ancestor;
if(son[x]) Connect(son[x],ancestor);
for(int i=head[x],v; ~i; i=next[i])
{
v=to[i];
if(v==son[x]||v==fa[x]) continue;
Connect(v,v);
}
}
void update(int rt,int l,int r,int val)
{
if(tree[rt].l>=l&&tree[rt].r<=r)
{
tree[rt].num=1;
tree[rt].lc=tree[rt].rc=val;
return ;
}
int mid=(tree[rt].l+tree[rt].r)>>1;
pushdown(rt);
if(l<=mid) update(rt<<1,l,r,val);
if(r>mid) update(rt<<1|1,l,r,val);
pushup(rt);
}
int getsum(int rt,int l,int r)
{
if(l==tree[rt].l) LC=tree[rt].lc;
if(r==tree[rt].r) RC=tree[rt].rc;
if(tree[rt].l>=l&&tree[rt].r<=r)
return tree[rt].num;
int mid=(tree[rt].l+tree[rt].r)>>1;
pushdown(rt);
int ls=rt<<1,rs=ls+1;
int res=0,c1=-1,c2=-1;
if(l<=mid) res+=getsum(ls,l,r),c1=tree[ls].rc;
if(r>mid) res+=getsum(rs,l,r),c2=tree[rs].lc;
if(c1==c2&&c1!=-1) res--;
return res;
}
void change(int x,int y,int d)
{
int gx=top[x],gy=top[y];
while(gx!=gy)
{
if(dep[gx]<dep[gy])
{
swap(gx,gy);
swap(x,y);
}
update(1,tid[gx],tid[x],d);
x=fa[gx];
gx=top[x];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,tid[x],tid[y],d);
}
int query(int x,int y)
{
int gx=top[x],gy=top[y];
int res=0;
int cx=-1,cy=-1;
while(gx!=gy)
{
if(dep[gx]<dep[gy])
{
swap(gx,gy);
swap(x,y);
swap(cx,cy);
}
res+=getsum(1,tid[gx],tid[x]);
if(cx==RC) res--;
cx=LC;
x=fa[gx];
gx=top[x];
}
if(dep[x]>dep[y]) swap(x,y),swap(cx,cy);
res+=getsum(1,tid[x],tid[y]);
if(cx==LC) res--;
if(cy==RC) res--;
return res;
}
int main()
{
int n,m;
char op[5];
while(~scanf("%d%d",&n,&m))
{
init();
int a,b,c;
for(int i=1; i<=n; i++) scanf("%d",&color[i]);
for(int i=1; i<n; i++)
{
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
Find(root);
Connect(root,root);
for(int i=1; i<=n; i++) p[tid[i]]=i;
build(1,1,label);
while(m--)
{
scanf("%s",op);
if(op[0]=='Q')
{
scanf("%d%d",&a,&b);
printf("%d\n",query(a,b));
}
else
{
scanf("%d%d%d",&a,&b,&c);
change(a,b,c);
}
}
}
return 0;
}