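/*
 * Heavy-light decomposition (tree chain partition) problem.
 * After decomposing the tree, a segment tree over the DFS order stores, for each
 * interval, the colour at its first and last position plus the number of colour
 * segments it contains.
 * Merging two intervals: if the left interval's last colour equals the right
 * interval's first colour, a[n].num = a[lch].num + a[rch].num - 1; otherwise
 * a[n].num = a[lch].num + a[rch].num.
 * When climbing heavy chains and accumulating the answer, check whether the chain
 * head and its parent have the same colour; if they do, decrement sum, because the
 * two pieces form a single segment.
 */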
/**************************************************************
Problem: 2243
User: syh0313
Language: C++
Result: Accepted
Time:4928 ms
Memory:33808 kb
****************************************************************/
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
// Shorthands for the left / right child index of segment-tree node `n`.
// They expand textually, so a variable named `n` must be in scope wherever
// they are used. The expansion is parenthesized so it always acts as a
// single (assignable) operand.
#define lch (a[n].lc)
#define rch (a[n].rc)
using namespace std;
// Capacity of every per-vertex / per-edge array (vertices are 1-based).
const int maxn = 200000;

// Input: vertex count, operation count, initial colour of each vertex.
int n, m;
int v[maxn];

// Adjacency list: st[x] is the first edge leaving x, to[e] its endpoint,
// nt[e] the next edge in the same list; topt is the last edge index used.
int st[maxn], to[maxn], nt[maxn];
int topt;

// Further counter and the index of the segment-tree root (set by build_tree).
int cnt;
int root;

// Heavy-light decomposition data (filled by dfs1 / dfs2):
//   dep - depth (the root has depth 1), fa - parent, si - subtree size,
//   rem - heavy child (0 when x has no child), top - head of x's heavy chain,
//   dfn - DFS position of x, line - inverse mapping (position -> vertex),
//   dfn_num - last DFS position handed out.
int dep[maxn], fa[maxn], rem[maxn], si[maxn];
int dfn_num, dfn[maxn], line[maxn], top[maxn];

// Visited flags shared by the two DFS passes.
bool f[maxn];

// Buffer for the operation token read in main.
char s[maxn];

// Segment-tree node: child indices, number of colour segments in the node's
// interval, colour at its first / last position, pending colour-assignment tag.
struct da {
    int lc, rc, num, head, tail, tag;
} a[maxn << 2];
// Insert the directed edge x -> y at the front of x's adjacency list.
void add(int x, int y)
{
    const int e = ++topt;   // index of the new edge
    to[e] = y;
    nt[e] = st[x];          // previous head becomes the successor
    st[x] = e;
}
// First decomposition pass, starting at vertex x with depth de.
// Marks every reached vertex visited and records its depth, parent and subtree
// size. It also sets rem[x] to the child with the largest subtree (the first
// such child wins on ties, because the comparison is strict).
void dfs1(int x, int de)
{
    dep[x] = de;
    f[x] = 1;
    si[x] = 1;
    int best = 0;   // largest child subtree size seen so far
    for (int e = st[x]; e; e = nt[e]) {
        const int y = to[e];
        if (f[y]) continue;   // already visited (the parent, in a tree)
        dfs1(y, de + 1);
        fa[y] = x;
        si[x] += si[y];
        if (si[y] > best) {
            best = si[y];
            rem[x] = y;
        }
    }
}
// Second decomposition pass: hands out DFS positions (dfn / line) and chain
// heads (top). The heavy child is visited first, so each heavy chain occupies
// consecutive positions. Expects fa[] and rem[] from dfs1 and cleared f[].
void dfs2(int x)
{
    f[x] = 1;
    // x continues its parent's chain exactly when it is the parent's heavy child.
    top[x] = (rem[fa[x]] == x) ? top[fa[x]] : x;
    dfn[x] = ++dfn_num;
    line[dfn_num] = x;
    if (rem[x])
        dfs2(rem[x]);
    for (int e = st[x]; e; e = nt[e])
        if (!f[to[e]])
            dfs2(to[e]);
}
void
updata(
int
n)
{
if
(a[lch].tail==a[rch].head) a[n].num=a[lch].num+a[rch].num-1;
else
a[n].num=a[lch].num+a[rch].num;
a[n].head=a[lch].head; a[n].tail=a[rch].tail;
}
void
build_tree(
int
&n,
int
l,
int
r)
{
n=++topt;
if
(l==r)
{a[n].num=1; a[n].head=v[line[l]]; a[n].tail=v[line[l]];
return
;}
int
mid=(l+r)>>1;
build_tree(lch,l,mid); build_tree(rch,mid+1,r);
updata(n);
}
void
pushdown(
int
n)
{
if
(!a[n].tag)
return
;
a[lch].num=a[rch].num=1;
a[lch].head=a[lch].tail=a[rch].head=a[rch].tail=a[n].tag;
a[lch].tag=a[rch].tag=a[n].tag;
a[n].tag=0;
}
void
tree_fz(
int
n,
int
L,
int
R,
int
l,
int
r,
int
k)
{
if
(L==l && R==r)
{
a[n].tag=k; a[n].num=1; a[n].head=k; a[n].tail=k;
return
;
}
pushdown(n);
int
mid=(L+R)>>1;
if
(r<=mid) tree_fz(lch,L,mid,l,r,k);
else
if
(l>=mid+1) tree_fz(rch,mid+1,R,l,r,k);
else
tree_fz(lch,L,mid,l,mid,k),tree_fz(rch,mid+1,R,mid+1,r,k);
updata(n);
}
// Paint every vertex on the tree path x..y with colour k.
// Repeatedly paint the whole chain piece of whichever endpoint has the deeper
// chain head, then jump above that head. Once both endpoints share a chain,
// paint the remaining contiguous dfn range.
void qfz(int x, int y, int k)
{
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        tree_fz(root, 1, n, dfn[top[x]], dfn[x], k);
    }
    if (dfn[x] > dfn[y])
        swap(x, y);
    tree_fz(root, 1, n, dfn[x], dfn[y], k);
}
// Number of colour segments among positions l..r. Precondition: [l, r] lies
// inside node n's interval [L, R].
// When the range straddles mid, the counts of both halves are added, minus one
// if the colours at positions mid and mid+1 coincide (the segment crossing the
// boundary would otherwise be counted twice).
// (The original had an unreachable updata(n) after branches that all return;
// it has been removed.)
int qury(int n, int L, int R, int l, int r)
{
    if (L == l && R == r)
        return a[n].num;
    pushdown(n);
    const int mid = (L + R) >> 1;
    if (r <= mid)
        return qury(lch, L, mid, l, r);
    if (l >= mid + 1)
        return qury(rch, mid + 1, R, l, r);
    // After pushdown, these are the colours at positions mid and mid+1.
    const int joined = (a[lch].tail == a[rch].head) ? 1 : 0;
    return qury(lch, L, mid, l, mid) + qury(rch, mid + 1, R, mid + 1, r) - joined;
}
// Colour at the single position lc. Precondition: lc lies inside node n's
// interval [l, r].
// (The original's trailing updata(n) was unreachable, and its extra `l == lc`
// leaf test is implied by the precondition; both have been removed.)
int qcol(int n, int l, int r, int lc)
{
    if (l == r)
        return a[n].head;
    pushdown(n);
    const int mid = (l + r) >> 1;
    return (lc <= mid) ? qcol(lch, l, mid, lc) : qcol(rch, mid + 1, r, lc);
}
// Number of colour segments on the tree path x..y.
// Each heavy-chain piece is counted with qury(). If a chain head and its parent
// share a colour, the piece continues the segment above it, so one is
// subtracted for that boundary.
int qsum(int x, int y)
{
    int total = 0;
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        const int head = top[x];
        total += qury(root, 1, n, dfn[head], dfn[x]);
        if (qcol(root, 1, n, dfn[fa[head]]) == qcol(root, 1, n, dfn[head]))
            --total;
    }
    if (dfn[x] > dfn[y])
        swap(x, y);
    return total + qury(root, 1, n, dfn[x], dfn[y]);
}
// Reads the tree (n vertices, their colours, n-1 edges), builds the
// decomposition and segment tree, then processes m operations:
//   "Q a b"   - print the number of colour segments on path a..b
//   otherwise - read colour c and paint path a..b with it
// Exits with status 1 on malformed or out-of-range input, instead of
// continuing with uninitialized or out-of-bounds values.
int main()
{
    if (scanf("%d%d", &n, &m) != 2)
        return 1;
    // The per-vertex and per-edge arrays all share v's capacity. Vertices use
    // indices 1..n and the 2(n-1) directed edges use indices 1..2(n-1).
    const int cap = (int)(sizeof v / sizeof v[0]);
    if (n < 1 || 2 * (n - 1) >= cap)
        return 1;
    for (int i = 1; i <= n; i++)
        if (scanf("%d", &v[i]) != 1)
            return 1;
    for (int i = 1; i < n; i++) {
        int xx, yy;
        if (scanf("%d%d", &xx, &yy) != 2)
            return 1;
        if (xx < 1 || xx > n || yy < 1 || yy > n)
            return 1;
        add(xx, yy);
        add(yy, xx);
    }
    dfs1(1, 1);
    memset(f, 0, sizeof f);   // dfs2 reuses the visited flags
    dfs2(1);
    build_tree(root, 1, n);
    for (int i = 1; i <= m; i++) {
        int l, r, xx;
        if (scanf("%s%d%d", s, &l, &r) != 3)
            return 1;
        if (l < 1 || l > n || r < 1 || r > n)
            return 1;
        if (s[0] == 'Q') {
            printf("%d\n", qsum(l, r));
        } else {
            if (scanf("%d", &xx) != 1)
                return 1;
            qfz(l, r, xx);
        }
    }
    return 0;
}