这道题是点的树链剖分。
修改链,询问链。
树链剖分的意思就是将树线性化
树剖之后就变成区间修改和区间询问了。
询问的是颜色段数量,所以就是区间合并,并且树链合并时还要判断两个链的端点是否颜色相同。
(3个半小时,菜不成声)
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#define msc(X) memset(X,-1,sizeof(X))
using namespace std;
const int MAXN=1e5+5;
struct Edge{
int to,next;
}edge[MAXN<<1];
int tot,hd[MAXN];
int siz[MAXN],son[MAXN],fa[MAXN],deep[MAXN];
int p[MAXN],pos,top[MAXN];
void init(void)
{
tot=pos=0;
msc(hd);
msc(son);
}
void addedge(int u,int v)
{
edge[tot].to=v;
edge[tot].next=hd[u];
hd[u]=tot++;
}
void dfs1(int u,int pre,int d)
{
deep[u]=d;
fa[u]=pre;
siz[u]=1;
for(int i=hd[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(v!=pre){
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(son[u]==-1||siz[v]>siz[son[u]])
son[u]=v;
}
}
}
void getpos(int u,int spy)
{
top[u]=spy;
p[u]=pos++;
if(son[u]==-1) return;
getpos(son[u],spy);
for(int i=hd[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(v!=fa[u]&&v!=son[u])
getpos(v,v);
}
}
//SegTree
struct Tree{
int l,r,lclr,rclr,dc,lazy;
}tree[MAXN<<2];
int clr[MAXN];
void build(int tn,int l,int r)
{
tree[tn].l=l,tree[tn].r=r;
tree[tn].lclr=tree[tn].rclr=tree[tn].lazy=-1;
tree[tn].dc=0;
if(l==r) return;
int mid=(l+r)>>1;
build(tn<<1,l,mid);
build(tn<<1|1,mid+1,r);
}
void push_up(int tn)
{
tree[tn].lclr=tree[tn<<1].lclr;
tree[tn].rclr=tree[tn<<1|1].rclr;
tree[tn].dc=tree[tn<<1].dc+tree[tn<<1|1].dc-(tree[tn<<1].rclr==tree[tn<<1|1].lclr);
}
void push_down(int tn)
{
if(tree[tn].lazy!=-1){
tree[tn<<1].lazy=tree[tn<<1|1].lazy=tree[tn<<1].lclr=tree[tn<<1].rclr=tree[tn<<1|1].lclr=tree[tn<<1|1].rclr=tree[tn].lazy;
tree[tn<<1].dc=tree[tn<<1|1].dc=1;
tree[tn].lazy=-1;
}
}
void update(int tn,int l,int r,int value)
{
if(l<=tree[tn].l&&tree[tn].r<=r){
tree[tn].lclr=tree[tn].rclr=tree[tn].lazy=value;
tree[tn].dc=1;
return ;
}
push_down(tn);
int mid=(tree[tn].l+tree[tn].r)>>1;
if(l<=mid) update(tn<<1,l,r,value);
if(r>mid) update(tn<<1|1,l,r,value);
push_up(tn);
}
int sum(int tn,int l,int r)
{
if(l<=tree[tn].l&&tree[tn].r<=r)
return tree[tn].dc;
push_down(tn);
int mid=(tree[tn].l+tree[tn].r)>>1;
if(r<=mid) return sum(tn<<1,l,r);
else if(l>mid) return sum(tn<<1|1,l,r);
else return sum(tn<<1,l,mid)+sum(tn<<1|1,mid+1,r)-(tree[tn<<1].rclr==tree[tn<<1|1].lclr);
}
void chang(int u,int v,int value)
{
int f1=top[u],f2=top[v];
while(f1!=f2){
if(deep[f1]<deep[f2]){
swap(u,v);
swap(f1,f2);
}
update(1,p[top[u]],p[u],value);
u=fa[f1],f1=top[u];
}
if(deep[u]>deep[v]) swap(u,v);
update(1,p[u],p[v],value);
}
int color(int tn,int x)
{
if(tree[tn].lazy!=-1)
return tree[tn].lazy;
push_down(tn);
int mid=(tree[tn].l+tree[tn].r)>>1;
return x<=mid ? color(tn<<1,x) : color(tn<<1|1,x);
}
int query(int u,int v)
{
int f1=top[u],f2=top[v];
int rt=0;
//printf("f%d %d\n",f1,f2);
while(f1!=f2){
if(deep[f1]<deep[f2]){
swap(u,v);
swap(f1,f2);
}
rt+=sum(1,p[f1],p[u])-(color(1,p[f1])==color(1,p[fa[f1]]));
u=fa[f1],f1=top[u];
}
if(deep[u]>deep[v]) swap(u,v);
// printf("rt%d: pu: %d v %d color %d %d\n",rt,p[u],p[v],color(1,p[u]),color(1,p[v]));
return rt+sum(1,p[u],p[v]);
}
int main(void)
{
int n,m;
init();
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",clr+i);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
dfs1(1,0,0);
getpos(1,1);
build(1,0,pos-1);
for(int i=1;i<=n;i++)
update(1,p[i],p[i],clr[i]);
while(m--){
char op[3];
scanf("%s",op);
if(op[0]=='C'){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
chang(a,b,c);
}
else{
int a,b;
scanf("%d%d",&a,&b);
printf("%d\n",query(a,b));
}
}
return 0;
}