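// Heavy-light decomposition + segment tree with range assignment.
// The segment-tree part is a standard template; the key point is merging the
// results across chains.
// For every chain segment the query path passes through, record its start node
// and end node together with their current colors.
// Then scan the recorded nodes: whenever a node and its parent are both on the
// path, lie on different chains, and have the same color, subtract one from
// the answer.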
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<stack>
// Segment-tree children of node i (heap layout, root = 1). The whole expansion
// is parenthesised so the macros stay correct inside larger expressions such
// as `lson+1` or `rson*2`.
#define lson (i<<1)
#define rson ((i<<1)+1)
// Capacity of node-indexed arrays (node ids are 1..n with n < maxn).
const int maxn = 100005;
using namespace std;

int n,m;  // n = number of tree nodes, m = number of operations

// Heavy-light decomposition data, indexed by node id.
int son[maxn];   // heavy child (child with the largest subtree); 0 for a leaf
int prin[maxn];  // initial color of each node as read from input
int top[maxn];   // topmost node of the heavy chain containing the node
int deep[maxn];  // depth; the root (node 1) has depth 1
int fa[maxn];    // parent; 0 for the root
int siz[maxn];   // subtree size
int ha[maxn];    // position of the node in segment-tree order (1..l)

// Adjacency list: last[v] = newest edge leaving v (-1 if none),
// Link[e] = target of edge e, fron[e] = next older edge; edgel = edges used.
int last[maxn*2],Link[maxn*2],fron[maxn*2],edgel;

// Segment tree over positions 1..l: sum = number of maximal same-color runs in
// the node's range, lco/rco = color of its leftmost/rightmost position.
int sum[maxn*4],lco[maxn*4],rco[maxn*4];
int l;           // number of positions handed out so far (one per node visited by dfs2)
int va[maxn];    // va[p] = initial color of the node placed at position p

// Scratch state for path queries.
int flag[maxn];      // color recorded for a node during find(); -1 = none
int preLco,preRco;   // boundary colors reported by the most recent query()
int st[maxn];        // nodes whose colors were recorded during find()
// Append the directed edge a -> b to the adjacency list.
// Each source node owns a singly linked list of edge indices: last[a] is the
// newest edge, fron[e] links to the next older one, and Link[e] is the target.
void add(int a,int b)
{
    const int e = edgel++;
    Link[e] = b;
    fron[e] = last[a];
    last[a] = e;
}
// First heavy-light pass over the subtree rooted at `pre`.
// Fills in parent (fa) and depth (deep) of every child, subtree sizes (siz)
// and the heavy child (son, 0 when `pre` is a leaf).
// The caller must already have set fa[pre] and deep[pre]; recursion depth
// equals the height of the tree.
void dfs(int pre)
{
    son[pre] = 0;
    siz[pre] = 1;
    for(int e=last[pre];e!=-1;e=fron[e])
    {
        const int child = Link[e];
        if(child==fa[pre])continue;   // skip the edge back to the parent
        fa[child] = pre;
        deep[child] = deep[pre]+1;
        dfs(child);
        if(siz[child]>siz[son[pre]])son[pre] = child;
        siz[pre] += siz[child];
    }
}
// Second heavy-light pass: gives every node a segment-tree position (ha) such
// that each heavy chain occupies consecutive positions (the heavy child is
// visited first), records the chain top of every node (top), and copies each
// node's initial color into position order (va).
void dfs2(int pre,int root)
{
    top[pre] = root;
    ++l;
    ha[pre] = l;
    va[l] = prin[pre];
    if(son[pre]!=0)dfs2(son[pre],root);      // heavy child continues this chain
    for(int e=last[pre];e!=-1;e=fron[e])
    {
        const int v = Link[e];
        if(v==son[pre]||v==fa[pre])continue;
        dfs2(v,v);                             // light child starts a new chain
    }
}
void pushUp(int i);  // defined further down; build uses the same merge rule

// Build segment-tree node i covering positions [l, r] from the colors in va[].
// A leaf is a single run whose left and right colors are its own color; an
// inner node is merged from its children by pushUp (instead of repeating that
// logic here).
void build(int i,int l,int r)
{
    if(l==r)
    {
        sum[i] = 1;
        lco[i] = va[l];
        rco[i] = va[l];
        return ;
    }
    int mid = (l+r)/2;
    build(lson,l,mid);
    build(rson,mid+1,r);
    pushUp(i);  // runs add up, minus one if they meet with equal colors
}
void pushUp(int i)
{
sum[i] = sum[lson]+sum[rson];
if(rco[lson]==lco[rson])sum[i]--;
lco[i] = lco[lson];
rco[i] = rco[rson];
}
void pushDown(int i)
{
if(sum[i]==1)
{
sum[lson] = 1;
sum[rson] = 1;
lco[lson] = lco[i];
lco[rson] = lco[i];
rco[lson] = lco[i];
rco[rson] = lco[i];
}
}
// Paint positions [L, R] with color c. Node i covers [l, r], and [L, R] must
// lie inside it. The request is clamped to each half it overlaps.
void update(int i,int l,int r,int L,int R,int c)
{
    if(L==l&&R==r)
    {
        sum[i] = 1;              // the whole range becomes a single run
        lco[i] = rco[i] = c;
        return ;
    }
    pushDown(i);
    const int mid = (l+r)/2;
    if(L<=mid)update(lson,l,mid,L,R<mid?R:mid,c);
    if(R>mid)update(rson,mid+1,r,L>mid?L:mid+1,R,c);
    pushUp(i);
}
// Count the color runs on positions [L, R] inside node i's range [l, r].
// L1/R1 are the endpoints of the caller's whole query. The fully covered node
// that starts at L1 writes its left color to preLco, and the one that ends at
// R1 writes its right color to preRco, so the caller also learns both boundary
// colors.
int query(int i,int l,int r,int L,int R,int L1,int R1)
{
    if(L==l&&R==r)
    {
        if(l==L1)preLco = lco[i];
        if(r==R1)preRco = rco[i];
        return sum[i];
    }
    pushDown(i);
    const int mid = (l+r)/2;
    if(R<=mid)return query(lson,l,mid,L,R,L1,R1);
    if(L>mid)return query(rson,mid+1,r,L,R,L1,R1);
    int runs = query(lson,l,mid,L,mid,L1,R1);
    runs += query(rson,mid+1,r,mid+1,R,L1,R1);
    if(rco[lson]==lco[rson])--runs;   // runs touching at mid / mid+1 merge
    return runs;
}
void change(int a,int b,int c)
{
int faa = top[a],fab = top[b];
while(faa!=fab)
{
if(deep[faa]<deep[fab])
{
swap(faa,fab);
swap(a,b);
}
update(1,1,l,ha[faa],ha[a],c);
a = fa[faa];
faa = top[a];
}
if(deep[a]<deep[b])swap(a,b);
update(1,1,l,ha[b],ha[a],c);
}
// Count the color runs on the tree path a..b (both endpoints included).
// The path is split into heavy-chain segments. query() counts the runs inside
// each segment and reports the segment's end colors through preLco/preRco.
// A run that continues across a chain boundary (a chain top and its parent
// having the same color) is counted once per segment, so it is subtracted in
// the final scan.
// Side effects: writes the global scratch arrays st[] and flag[]; every flag
// set here is reset to -1 before returning.
int find(int a,int b)
{
int faa = top[a],fab = top[b],ans = 0,l1 = 0;  // faa/fab: chain tops of a/b; l1: entries used in st[]
while(faa!=fab)
{
// lift the endpoint whose chain top is deeper
if(deep[faa]<deep[fab])
{
swap(faa,fab);
swap(a,b);
}
// runs on positions ha[faa]..ha[a]; also sets preLco = color of faa, preRco = color of a
ans+=query(1,1,l,ha[faa],ha[a],ha[faa],ha[a]);
// remember both ends of this segment (only once if it is a single node)
st[l1] = faa;
l1++;
if(faa!=a)
{
st[l1] = a;
l1++;
}
flag[faa] = preLco;
flag[a] = preRco;
// continue from the parent of the chain top; that node is the deep end of a later segment
a = fa[faa];
faa = top[a];
}
// a and b are now on the same chain; make b the shallower one (the top of this last segment)
if(deep[a]<deep[b])swap(a,b);
ans+=query(1,1,l,ha[b],ha[a],ha[b],ha[a]);
st[l1] = b;
l1++;
if(b!=a)
{
st[l1] = a;
l1++;
}
flag[b] = preLco;
flag[a] = preRco;
// A recorded node whose parent is on a different chain is a chain top. If the
// parent's recorded color equals its own, the two segments share one run.
// The parent of the last segment's top b is not on the path and still holds
// flag -1, so it never matches.
// NOTE(review): this relies on no input color being -1 — confirm against the input constraints.
for(int i=0;i<l1;i++)
{
int pre = st[i];
if(flag[pre]==flag[fa[pre]]&&top[pre]!=top[fa[pre]])ans--;
}
// restore the "no color recorded" sentinel for the next query
for(int i=0;i<l1;i++)flag[st[i]] = -1;
return ans;
}
int main()
{
scanf("%d %d",&n,&m);
memset(last,-1,sizeof(last));
memset(flag,-1,sizeof(flag));
l = 0,edgel = 0;
fa[1] = 0;deep[1] = 1;
for(int i=1;i<=n;i++)scanf("%d",&prin[i]);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d %d",&a,&b);
add(a,b);
add(b,a);
}
dfs(1);
dfs2(1,1);
build(1,1,l);
while(m--)
{
char s[5];
int a,b,c;
scanf("%s",s);
if(s[0]=='Q')
{
scanf("%d %d",&a,&b);
printf("%d\n",find(a,b));
}
else {
scanf("%d %d %d",&a,&b,&c);
change(a,b,c);
}
}
return 0;
}