题目
题目大意:一棵树,每个点权是一个颜色。
支持两种操作:
1.修改任意一条链,使得这条链上的颜色均为c;
2.询问任意一条链上的颜色段数。
例如:11221 颜色段数为3
分析
大体思路
由于是链上操作,考虑树链剖分+线段树的做法
这两个操作都是区间修改,1操作直接普通的区间修改+dag标记即可实现
线段树合并细节处理
下面重点来讲讲2:
线段树的查询中,假设查询的区间为 [ l , r ] [l,r] [l,r]定义两个全局变量 l s t l 、 l a t r lstl、latr lstl、latr表示 c o l [ l ] 、 c o l [ r ] col[l]、col[r] col[l]、col[r]
那么在返回答案的时候就是:
if(t[rt].l>=l&&t[rt].r<=r){
if(t[rt].l==l)lstl=t[rt].lc;
if(t[rt].r==r)lstr=t[rt].rc;//就是这两行是新加的
return t[rt].sum;
}
注意,线段树里面是树上节点的dfs序。
这一步具体为什么要记录这两个信息,下面会讲。
树链查找细节处理
由于树链剖分的时候,xy两个点会你跳几下我跳几下,把x节点延伸出去的链称为x链,y同理,那么整条链就被分成了3个部分:x链、y链和中间部分。
x链来讲一端点已经固定,定义 a n s l ansl ansl表示另一个不固定的端点,定义 a n s r ansr ansr为y链上的不固定端点。
那么以x向上跳为例:
if(dep[fx]>=dep[fy]){
ans+=query(1,id[fx],id[x]);
x=f[fx];
fx=top[x];
if(lstr==ansl)ans--;
ansl=lstl;
}
如果当前 a n s l ansl ansl等于这个新增区间的 l s t r lstr lstr,即出现重合部分,那么答案减一
然后是跳到同一条重链之后,要分别考虑x处y处是否有颜色相同部分。
if(dep[x]<=dep[y]){
ans+=query(1,id[x],id[y]);
if(ansl==lstl)ans--;
if(ansr==lstr)ans--;
}
else{
ans+=query(1,id[y],id[x]);
if(ansl==lstr)ans--;
if(ansr==lstl)ans--;
}
代码
#include<bits/stdc++.h>
using namespace std;
int read(){
char s;
int x=0,f=1;
s=getchar();
while(s<'0'||s>'9'){
if(s=='-')f=-1;
s=getchar();
}
while(s>='0'&&s<='9'){
x*=10;
x+=s-'0';
s=getchar();
}
return x*f;
}
const int N=1e5+5;
int n,m;
int a[N];
vector<int>v[N];
int son[N],size[N],dep[N],f[N];
void dfs1(int x,int fa){
size[x]=1;
dep[x]=dep[fa]+1;
f[x]=fa;
int maxn=-1;
for(int i=0;i<v[x].size();i++){
int k=v[x][i];
if(k==fa)continue;
dfs1(k,x);
size[x]+=size[k];
if(size[k]>maxn){
maxn=size[k];
son[x]=k;
}
}
return;
}
int ti;
int top[N],id[N],b[N];
void dfs2(int x,int topf){
top[x]=topf;
id[x]=++ti;
b[id[x]]=a[x];
if(!son[x])return;
dfs2(son[x],topf);
for(int i=0;i<v[x].size();i++){
int k=v[x][i];
if(id[k])continue;
dfs2(k,k);
}
}
struct tree{
int l,r;
int sum;
int lc,rc;
int dag;
}t[N<<2];
void pushup(int rt){
t[rt].sum=t[rt<<1].sum+t[rt<<1|1].sum;
t[rt].lc=t[rt<<1].lc;
t[rt].rc=t[rt<<1|1].rc;
if(t[rt<<1].rc==t[rt<<1|1].lc)t[rt].sum-=1;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
t[rt].dag=-1;
if(l==r){
t[rt].lc=t[rt].rc=b[l];
t[rt].sum=1;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
void pushdown(int rt){
if(t[rt].dag==-1)return;
t[rt<<1].dag=t[rt<<1|1].dag=t[rt].dag;
t[rt<<1].sum=t[rt<<1|1].sum=1;
t[rt<<1].lc=t[rt<<1].rc=t[rt<<1|1].lc=t[rt<<1|1].rc=t[rt].dag;
t[rt].dag=-1;
return;
}
void modify(int rt,int l,int r,int c){
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].sum=1;
t[rt].dag=c;
t[rt].lc=t[rt].rc=c;
return;
}
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
if(l<=mid)modify(rt<<1,l,r,c);
if(mid+1<=r)modify(rt<<1|1,l,r,c);
pushup(rt);
return;
}
void modify_link(int x,int y,int c){
int fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]>=dep[fy]){
modify(1,id[fx],id[x],c);
x=f[fx];
fx=top[x];
}
else{
modify(1,id[fy],id[y],c);
y=f[fy];
fy=top[y];
}
}
if(dep[x]<=dep[y])modify(1,id[x],id[y],c);
else modify(1,id[y],id[x],c);
return;
}
int lstl,lstr;//记录最后一段的颜色
int query(int rt,int l,int r){
if(t[rt].l>=l&&t[rt].r<=r){
if(t[rt].l==l)lstl=t[rt].lc;
if(t[rt].r==r)lstr=t[rt].rc;
return t[rt].sum;
}
int ans=0;
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
if(r<=mid)ans=query(rt<<1,l,r);
else if(l>mid)ans=query(rt<<1|1,l,r);
else ans=query(rt<<1,l,r)+query(rt<<1|1,l,r)-bool(t[rt<<1].rc==t[rt<<1|1].lc);
pushup(rt);
return ans;
}
int ansl,ansr;
int query_link(int x,int y){
int fx=top[x],fy=top[y];
int ans=0;
ansl=ansr=-1;//x\y对应区间的内部的颜色
while(fx!=fy){
if(dep[fx]>=dep[fy]){
ans+=query(1,id[fx],id[x]);
x=f[fx];
fx=top[x];
if(lstr==ansl)ans--;
ansl=lstl;
}
else{
ans+=query(1,id[fy],id[y]);
y=f[fy];
fy=top[y];
if(lstr==ansr)ans--;
ansr=lstl;
}
}
if(dep[x]<=dep[y]){
ans+=query(1,id[x],id[y]);
if(ansl==lstl)ans--;
if(ansr==lstr)ans--;
}
else{
ans+=query(1,id[y],id[x]);
if(ansl==lstr)ans--;
if(ansr==lstl)ans--;
}
return ans;
}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
v[x].push_back(y);
v[y].push_back(x);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++){
char op;
cin>>op;
if(op=='C'){
int x=read(),y=read(),col=read();
modify_link(x,y,col);
}
else{
int x=read(),y=read();
printf("%lld\n",query_link(x,y));
}
}
}