调了三天,对着题解改代码……
什么叫超级蒟蒻啊?(战术后仰)
Luogu P2486 [SDOI2011]染色
题外话:不知为什么,虽然自己把树剖模板背的还是比较熟了,但每次做题时总要打错一些地方,总有那么些细节调不出来。最后还是要看题解,一个一个地比对才行。真担心考到树剖时自己能不能做出来……
一开始我以为这道题是维护区间颜色种类数,愣是不知道怎么维护(太蒻了),后来才看清楚是维护区间颜色段数,这样就只需判断两个区间相邻处的颜色是否相同,以更新答案。
在原有的树剖模板基础上加上lf[rt],rf[rt]两个数组表示线段树上rt节点包含的区间左端的颜色与右端的颜色,建树时即可维护出来。pushup更新父节点时,其左端点颜色用其左区间的左端点颜色更新,右区间亦然。
这是为了区间合并统计时,当左区间的右端与右区间的左端颜色相同时,统计答案减一。
注意,修改和询问操作都要判断左右区间的端点颜色从而得到答案。两个节点找lca,一个节点向上跳的时候也要判断重链顶端与其父亲节点的颜色是否相同只是因为询问时忘了判断就调了三天?
#include<bits/stdc++.h>
#define lson (rt<<1)
#define rson (rt<<1|1)
using namespace std;
const int maxn=1e5+5;
int n,m,cnt,tot;
int w[maxn],head[maxn],sum[maxn<<2],laz[maxn<<2],lf[maxn<<2],rf[maxn<<2];
int s[maxn],f[maxn],top[maxn],rev[maxn],siz[maxn],d[maxn],id[maxn];
struct edge
{
int v,nxt;
}e[maxn<<1];
void add(int u,int v)
{
e[++cnt].nxt=head[u];
e[cnt].v=v;
head[u]=cnt;
}
void dfs1(int u,int fa)
{
f[u]=fa;
d[u]=d[fa]+1;
siz[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[s[u]]) s[u]=v;
}
}
void dfs2(int u,int t)
{
id[u]=++tot;
rev[tot]=u;
top[u]=t;
if(!s[u]) return;
dfs2(s[u],t);
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].v;
if(v==f[u]||v==s[u]) continue;
dfs2(v,v);
}
}
void pushup(int rt)
{
sum[rt]=sum[lson]+sum[rson];
lf[rt]=lf[lson];rf[rt]=rf[rson];
if(rf[lson]==lf[rson]) --sum[rt];
return;
}
void pushdown(int rt)
{
if(laz[rt]==-1) return;
sum[rt]=sum[lson]=sum[rson]=1;
laz[lson]=laz[rson]=lf[lson]=rf[lson]=lf[rson]=rf[rson]=laz[rt];//注意把子区间的懒标记laz,断电颜色lf,rf都要更新,同时更新区间颜色段数量sum
laz[rt]=-1;return;
}
void build(int l,int r,int rt)
{
if(l==r)
{
lf[rt]=rf[rt]=w[rev[l]];//我写树剖常掉的坑,是w[rev[l]]而非w[l]
sum[rt]=1;
return;
}
int mid=(l+r)>>1;
build(l,mid,lson);
build(mid+1,r,rson);
pushup(rt);
}
void upd(int l,int r,int rt,int ql,int qr,int k)
{
if(ql<=l&&qr>=r)
{
sum[rt]=1;
lf[rt]=rf[rt]=laz[rt]=k;
return;
}
pushdown(rt);
int mid=(l+r)>>1;
if(ql<=mid) upd(l,mid,lson,ql,qr,k);
if(qr>mid) upd(mid+1,r,rson,ql,qr,k);
pushup(rt);
}
void update(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x,y);
upd(1,n,1,id[top[x]],id[x],k);//注意是从id[top[x]]到id[x],别搞反了
x=f[top[x]];
}
if(d[x]>d[y]) swap(x,y);
upd(1,n,1,id[x],id[y],k);
}
int que(int l,int r,int rt,int ql,int qr)
{
if(ql<=l&&qr>=r) return sum[rt];
pushdown(rt);
int mid=(l+r)>>1,num=0;
if(ql<=mid) num+=que(l,mid,lson,ql,qr);
if(qr>mid) num+=que(mid+1,r,rson,ql,qr);
if(ql<=mid&&qr>mid&&rf[lson]==lf[rson]&&r-l>1) --num;//左右两区间合并统计时才判断
return num;
}
int find(int l,int r,int rt,int p)
{
if(l==r) return lf[rt];
pushdown(rt);
int mid=(l+r)>>1;
if(p<=mid) return find(l,mid,lson,p);
else return find(mid+1,r,rson,p);
}
int query(int x,int y)
{
int ans=0,la=0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=que(1,n,1,id[top[x]],id[x]);
if(find(1,n,1,id[f[top[x]]])==find(1,n,1,id[top[x]])) --ans;
x=f[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=que(1,n,1,id[x],id[y]);
return ans?ans:1;
}
int main()
{
//freopen("input.txt","r",stdin);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int x,y,z,i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);build(1,tot,1);
memset(laz,-1,sizeof(laz));
char op[5];
for(int a,b,c,i=1;i<=m;i++)
{
scanf("%s",op);
if(op[0]=='C')
{
scanf("%d%d%d",&a,&b,&c);
update(a,b,c);
}
else
{
scanf("%d%d",&a,&b);
printf("%d\n",query(a,b));
}
}
return 0;
}