题目
内容
有一棵n个节点的树,每个节点有颜色,支持两种操作:
- 把a到b这条链上的所有节点全部染成颜色c.
- 询问a到b这条链上的颜色段数量.
分析
树上的链操作摆明了要用树链剖分.用线段树来维护信息.
线段树的每个节点需要储存,区间的颜色数量,以及区间左端点的颜色和右端点的颜色.
以下是线段树中特殊操作
如何向上更新节点
先把左右儿子的颜色段数加起来.如果左儿子的右端点和右儿子的左端点是同一种颜色,那么颜色段数-1
void pushup(int k)
{
int lson=k<<1,rson=k<<1|1;
data[k]=data[lson]+data[rson];
cl[k]=cl[lson],cr[k]=cr[rson];
if(cr[lson]==cl[rson]) data[k]--;
}
如何询问区间
询问返回一个pair<int,pair<int,int>> 类型的元素, 表示区间的颜色段树,以及左右端点的颜色.
合并答案的时候同样按照向上更新节点的方式合并
pair<int,pii> query(int l,int r,int k)
{
pair<int,pii> ans,now;
ans.fi=ans.se.fi=ans.se.se=0;
if(L[k]>=l&&R[k]<=r)
{
ans.fi=data[k];
ans.se.fi=cl[k],ans.se.se=cr[k];
return ans;
}
pushdown(k);
int mid=(L[k]+R[k])>>1;
if(mid>=l)
{
ans=query(l,r,k<<1);
}
if(mid<r)
{
now=query(l,r,k<<1|1);
ans.fi+=now.fi;
if(ans.se.se==now.se.fi) ans.fi--;
if(ans.se.fi==0) ans.se.fi=now.se.fi;
ans.se.se=now.se.se;
}
return ans;
}
如何询问一条链
我们将链的端点定义问左节点和右节点.将左节点到lca定义为左链,将lca到右节点定义为左链.
分别计算左链和右链的答案,最后再合并到一起.
pair<int,pii> getsum(int l,int r)
{
pair<int,pii> ansl,ansr,now;
ansl.fi=ansl.se.fi=ansl.se.se=0;
ansr.fi=ansr.se.fi=ansr.se.se=0;
while(tp[l]!=tp[r])
{
if(dep[tp[l]]>dep[tp[r]])
{
now=tree.query(id[tp[l]],id[l],1);
ansl.fi+=now.fi;
if(ansl.se.fi==now.se.se) ansl.fi--;
if(ansl.se.se==0) ansl.se.se=now.se.se;
ansl.se.fi=now.se.fi;
l=fa[tp[l]];
}
else
{
now=tree.query(id[tp[r]],id[r],1);
ansr.fi+=now.fi;
if(ansr.se.fi==now.se.se) ansr.fi--;
if(ansr.se.se==0) ansr.se.se=now.se.se;
ansr.se.fi=now.se.fi;
r=fa[tp[r]];
}
}
if(dep[l]>dep[r])
{
now=tree.query(id[r],id[l],1);
ansl.fi+=now.fi;
if(ansl.se.fi==now.se.se) ansl.fi--;
if(ansl.se.se==0) ansl.se.se=now.se.se;
ansl.se.fi=now.se.fi;
}
else
{
now=tree.query(id[l],id[r],1);
ansr.fi+=now.fi;
if(ansr.se.fi==now.se.se) ansr.fi--;
if(ansr.se.se==0) ansr.se.se=now.se.se;
ansr.se.fi=now.se.fi;
}
ansl.fi+=ansr.fi;
if(ansl.se.fi==ansr.se.fi) ansl.fi--;
return ansl;
}
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
#define pb push_back
#define fi first
#define se second
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);i>=(b);i--)
const int maxn=1e5+10;
const int mod=1e9+7;
const int inf=0x3f3f3f3f;
int dep[maxn],son[maxn],size[maxn],tp[maxn],fa[maxn],book[maxn];
int id[maxn],rk[maxn],cnt=0;
vector<int> g[maxn];
void dfs1(int now,int f,int d)
{
dep[now]=d,fa[now]=f,size[now]=1;
for(auto x:g[now])
{
if(x==f) continue;
dfs1(x,now,d+1);
size[now]+=size[x];
if(size[x]>size[son[now]]) son[now]=x;
}
}
void dfs2(int now,int t)
{
tp[now]=t,id[now]=++cnt,rk[cnt]=now;
if(!son[now]) return ;
dfs2(son[now],t);
for(auto x: g[now])
{
if(x==fa[now]||x==son[now]) continue;
dfs2(x,x);
}
}
struct SegmentTree
{
int data[maxn<<2],lz[maxn<<2],L[maxn<<2],R[maxn<<2],cl[maxn<<2],cr[maxn<<2];
void pushup(int k)
{
int lson=k<<1,rson=k<<1|1;
data[k]=data[lson]+data[rson];
cl[k]=cl[lson],cr[k]=cr[rson];
if(cr[lson]==cl[rson]) data[k]--;
}
void build(int l,int r,int k)
{
L[k]=l,R[k]=r,lz[k]=0;
if(l==r)
{
data[k]=1;
cl[k]=cr[k]=book[rk[l]];
return ;
}
int mid=(l+r)>>1;
build(l,mid,k<<1);
build(mid+1,r,k<<1|1);
pushup(k);
}
void pushdown(int k)
{
if(lz[k]==0) return ;
int lson=k<<1,rson=k<<1|1;
lz[lson]=lz[k],lz[rson]=lz[k];
data[lson]=1,data[rson]=1;
cl[lson]=cr[lson]=cl[rson]=cr[rson]=lz[k];
lz[k]=0;
}
void change(int l,int r,int k,int x)
{
if(L[k]>=l&&R[k]<=r)
{
data[k]=1;
cl[k]=cr[k]=x;
lz[k]=x;
return ;
}
pushdown(k);
int mid=(L[k]+R[k])>>1;
if(mid>=l) change(l,r,k<<1,x);
if(mid<r) change(l,r,k<<1|1,x);
pushup(k);
}
pair<int,pii> query(int l,int r,int k)
{
pair<int,pii> ans,now;
ans.fi=ans.se.fi=ans.se.se=0;
if(L[k]>=l&&R[k]<=r)
{
ans.fi=data[k];
ans.se.fi=cl[k],ans.se.se=cr[k];
return ans;
}
pushdown(k);
int mid=(L[k]+R[k])>>1;
if(mid>=l)
{
ans=query(l,r,k<<1);
}
if(mid<r)
{
now=query(l,r,k<<1|1);
ans.fi+=now.fi;
if(ans.se.se==now.se.fi) ans.fi--;
if(ans.se.fi==0) ans.se.fi=now.se.fi;
ans.se.se=now.se.se;
}
return ans;
}
}tree;
void output(pair<int,pii>x)
{
cout<<x.fi<<' '<<x.se.fi<<' '<<x.se.se<<'\n';
}
pair<int,pii> getsum(int l,int r)
{
pair<int,pii> ansl,ansr,now;
ansl.fi=ansl.se.fi=ansl.se.se=0;
ansr.fi=ansr.se.fi=ansr.se.se=0;
while(tp[l]!=tp[r])
{
if(dep[tp[l]]>dep[tp[r]])
{
now=tree.query(id[tp[l]],id[l],1);
ansl.fi+=now.fi;
if(ansl.se.fi==now.se.se) ansl.fi--;
if(ansl.se.se==0) ansl.se.se=now.se.se;
ansl.se.fi=now.se.fi;
//cout<<l<<' '<<tp[l]<<'\n';
//output(ansl);
l=fa[tp[l]];
}
else
{
now=tree.query(id[tp[r]],id[r],1);
ansr.fi+=now.fi;
if(ansr.se.fi==now.se.se) ansr.fi--;
if(ansr.se.se==0) ansr.se.se=now.se.se;
ansr.se.fi=now.se.fi;
//cout<<r<<' '<<tp[r]<<'\n';
//output(ansr);
r=fa[tp[r]];
}
}
if(dep[l]>dep[r])
{
now=tree.query(id[r],id[l],1);
ansl.fi+=now.fi;
if(ansl.se.fi==now.se.se) ansl.fi--;
if(ansl.se.se==0) ansl.se.se=now.se.se;
ansl.se.fi=now.se.fi;
}
else
{
now=tree.query(id[l],id[r],1);
ansr.fi+=now.fi;
//output(now);
//output(ansr);
if(ansr.se.fi==now.se.se) ansr.fi--;
if(ansr.se.se==0) ansr.se.se=now.se.se;
ansr.se.fi=now.se.fi;
//output(ansr);
}
ansl.fi+=ansr.fi;
if(ansl.se.fi==ansr.se.fi) ansl.fi--;
return ansl;
}
void change(int l,int r,int c)
{
while(tp[l]!=tp[r])
{
if(dep[tp[l]]<dep[tp[r]]) swap(l,r);
tree.change(id[tp[l]],id[l],1,c);
l=fa[tp[l]];
}
if(dep[l]>dep[r]) swap(l,r);
tree.change(id[l],id[r],1,c);
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);
int n,m;
cin>>n>>m;
rep(i,1,n) cin>>book[i];
rep(i,1,n-1)
{
int u,v;
cin>>u>>v;
g[u].pb(v);
g[v].pb(u);
}
dfs1(1,0,1);
dfs2(1,1);
tree.build(1,n,1);
while(m--)
{
char ty;
int a,b,c;
cin>>ty;
if(ty=='C')
{
cin>>a>>b>>c;
change(a,b,c);
}
else if(ty=='Q')
{
cin>>a>>b;
pair<int,pii> now;
now=getsum(a,b);
cout<<now.fi<<'\n';
}
}
return 0;
}