这题交了30多次,比这别人的代码写,最后才弄懂。记得还有一道网络赛的题与他类似。
思路:
change操作很简单,难点在于区间操作。区间的合并只是一个简单的线段树的区间合并,要注意的是在查询颜色段的时候需要类比于正向查询时,要反向查询fa[tpu]与tpu的颜色关系,如果相同需要-1。
在区间查询时候,要需要注意记录当前的最右侧颜色,因为需要用它来与下一次查询的最左侧值来做比较,如果同则-1,不断更新右侧
#include <map>
#include <set>
#include <queue>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
const int MAXN = 100010;
struct node {
int left,right;
int left_color,right_color;
int sum;
int lazy;
}t[MAXN*4];
struct Edge
{
int to,next;
} edge[MAXN*4];
int fa[MAXN],son[MAXN],siz[MAXN],dep[MAXN],top[MAXN],id[MAXN],val[MAXN];
int n,q;
int col[MAXN];
int topw= 0;
int tot ,head[MAXN];
int cnt=0;
void Init() {
topw = 0;
memset(head,-1,sizeof(head));
memset(son,0,sizeof(son));
}
void addedge(int u,int v )
{
edge[cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt++;
}
void dfs1(int u,int f,int d)
{
dep[u]=d;
siz[u]=1;
son[u]=0;
fa[u]=f;
for(int i =head[u]; i!=-1; i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[u])
continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v])
son[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
id[u]=++topw;
if(son[u])
dfs2(son[u],tp);
for(int i=head[u]; i!=-1; i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[u]||v==son[u])
continue;
dfs2(v,v);
}
}
void push_up(int i)
{
t[i].sum=t[i<<1].sum+t[i<<1|1].sum;
if(t[i<<1].right_color==t[i<<1|1].left_color)
t[i].sum--;
t[i].right_color=t[i<<1|1].right_color;
t[i].left_color=t[i<<1].left_color;
}
void push_down(int i)
{
if(t[i].lazy)
{
t[i<<1|1].lazy=t[i<<1].lazy=t[i].lazy;
t[i<<1].left_color=t[i<<1].right_color=t[i].lazy;
t[i<<1|1].left_color=t[i<<1|1].right_color=t[i].lazy;
t[i<<1|1].sum=t[i<<1].sum=t[i].sum;
t[i].lazy=0;
}
}
void build(int i,int left,int right)
{
t[i].left=left,t[i].right=right,t[i].lazy=0;
if(left==right)
{
t[i].left_color=val[left];
t[i].right_color=val[left];
t[i].sum=1;
return ;
}
int mid=(t[i].left+t[i].right)>>1;
build(i<<1,left,mid);
build(i<<1|1,mid+1,right);
push_up(i);
return ;
}
void update(int i,int left,int right,int w)
{
if(left<=t[i].left&&t[i].right<=right)
{
t[i].lazy=w;
t[i].left_color=t[i].right_color=w;
t[i].sum=1;
return ;
}
push_down(i);
int mid=(t[i].left+t[i].right)>>1;
if(mid>=left)
{
update(i<<1,left,right,w);
}
if(mid<right)
{
update(i<<1|1,left,right,w);
}
push_up(i);
}
int right_color = 0;
int ANS = 0;
void query(int i,int left,int right)
{
if(left<=t[i].left&&t[i].right<=right)
{
if(right_color==t[i].left_color)
ANS+=t[i].sum-1;
else
ANS+=t[i].sum;
right_color = t[i].right_color;
return;
}
push_down(i);
int mid=(t[i].left+t[i].right)>>1;
if(mid>=left)
query(i<<1,left,right);
if(right>mid)
query(i<<1|1,left,right);
push_up(i);
}
int query_one(int i,int aim,int w)
{
if(t[i].left==t[i].right&&t[i].left==aim)
{
if(w==0)
return t[i].left_color;
return t[i].right_color;
}
push_down(i);
int mid=(t[i].left+t[i].right)>>1;
if(mid>=aim)
return query_one(i<<1,aim,w);
else
return query_one(i<<1|1,aim,w);
}
void change(int u,int v,int w)
{
int tpu=top[u],tpv=top[v];
while(tpu!=tpv)
{
if(dep[tpu]<dep[tpv])
{
swap(u,v);
swap(tpu,tpv);
}
update(1,id[tpu],id[u],w);
u=fa[tpu];
tpu=top[u];
}
if(dep[u]>dep[v])
swap(u,v);
update(1,id[u],id[v],w);
}
int ask(int u,int v)
{
int ans=0;
int tpu=top[u],tpv=top[v];
while(tpu!=tpv)
{
if(dep[tpu]<dep[tpv])
{
swap(u,v);
swap(tpu,tpv);
}
ANS=0;
right_color=0;
query(1,id[tpu],id[u]);
ans+=ANS;
u=fa[tpu];
tpu=top[u];
}
if(dep[u]>dep[v])
swap(u,v);
ANS=0,right_color=0;
query(1,id[u],id[v]);
ans+=ANS;
return ans;
}
int sub(int u,int v)
{
int ans=0;
int lc,rc;
int _left_color,_right_color;
int tpu=top[u],tpv=top[v];
while(tpu!=tpv)
{
if(dep[tpu]<dep[tpv])
{
swap(u,v);
swap(tpu,tpv);
}
_left_color = query_one(1,id[tpu],0);
_right_color = query_one(1,id[fa[tpu]],1);
if(_left_color == _right_color)
ans --;
u=fa[tpu];
tpu=top[u];
}
return ans;
}
int main()
{
scanf("%d%d",&n,&q);
topw=0,cnt=0;
memset(son,0,sizeof(son));
memset(edge,0,sizeof(edge));
memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++)
scanf("%d",&col[i]);
int u,v,a,b,c;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
topw=0;
dfs1(1,0,1);
dfs2(1,1);
for(int i = 1 ; i <= n ; i++)
{
val[ id[i] ] = col[i];
}
build(1,1,topw);
char op[5];
int color;
while(q--)
{
scanf("%s",op) ;
if(op[0] == 'C')
{
scanf("%d %d %d",&u,&v,&color);
change(u,v,color);
}
else
{
scanf("%d %d",&u,&v);
int ans = 0;
ans += ask(u,v);
ans += sub(u,v);
printf("%d\n",ans);
}
}
}
网上另一种代码,另一种解法http://blog.csdn.net/accelerator_/article/details/39737769?utm_source=tuicool