题解:
线段树节点维护区间最左边的颜色,最右边颜色,总段,然后查询时加入所有链的段数,然后根据颜色判断哪俩链多算了一段。
大水题。
本来以为20min思路清晰地打完一发180行代码很屌。
……
俩小错误。(没有任何可以借鉴的价值)
一个是EDIT(note<<1|1,l,r,x),写成了EDIT(note<<1|1,r,r,x),
一个是有一个地方忘了pushdown了。
写了好久,调了好久,拍了好久,WA了好久。什么 状态啊。
给代码、数据生成、拍子
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 101000
#define inf 0x3f3f3f3f
using namespace std;
struct Fiona
{
int u,v,len,next;
}e[N<<1];
int head[N],cnt;
void add(int u,int v)
{
++cnt;
e[cnt].u=u;
e[cnt].v=v;
e[cnt].next=head[u];
head[u]=cnt;
}
struct Segment_Tree
{
int l,r,first,last,sum,flag;
}s[N<<2];
int fa[N],son[N],top[N],pos[N],deep[N];
int dfs1(int x,int p)
{
int i,v,sum=1,ax=-inf;
deep[x]=deep[p]+1;
fa[x]=p;
for(i=head[x];i;i=e[i].next)
{
v=e[i].v;
if(v==p)continue;
int temp=dfs1(v,x);
sum+=temp;
if(ax<temp)ax=temp,son[x]=v;
}
return sum;
}
void dfs2(int x,int p,int r)
{
int i,v;
pos[x]=++cnt;
top[x]=r;
if(son[x])dfs2(son[x],x,r);
for(i=head[x];i;i=e[i].next)
{
v=e[i].v;
if(v!=p&&v!=son[x])dfs2(v,x,v);
}
}
int color[N],ttt[N];
int n,m,root=1;
char opt[5];
void pushup(int x)
{
s[x].sum=s[x<<1].sum+s[x<<1|1].sum;
s[x].first=s[x<<1].first,s[x].last=s[x<<1|1].last;
if(s[x<<1].last==s[x<<1|1].first)s[x].sum--;
}
void pushdown(int x)
{
if(s[x].flag)
{
if(s[x].l<s[x].r)
{
s[x<<1].flag=s[x].flag;
s[x<<1|1].flag=s[x].flag;
}
s[x].first=s[x].last=s[x].flag;
s[x].sum=1;
s[x].flag=0;
}
}
void build(int note,int l,int r)
{
s[note].l=l,s[note].r=r;
if(l==r)
{
s[note].sum=1;
s[note].first=s[note].last=color[l];
return ;
}
int mid=l+r>>1;
build(note<<1,l,mid);
build(note<<1|1,mid+1,r);
pushup(note);
return ;
}
int find(int note,int x)
{
pushdown(note);
if(s[note].l==s[note].r)return s[note].first;
int mid=s[note].l+s[note].r>>1;
if(x<=mid)return find(note<<1,x);
else return find(note<<1|1,x);
}
int ASK(int note,int l,int r)
{
pushdown(note);
if(s[note].l==l&&r==s[note].r)return s[note].sum;
int mid=s[note].l+s[note].r>>1;
if(r<=mid)return ASK(note<<1,l,r);
else if(l>mid)return ASK(note<<1|1,l,r);
else
{
int ans=ASK(note<<1,l,mid)+ASK(note<<1|1,mid+1,r);
if(s[note<<1].last==s[note<<1|1].first)ans--;
return ans;
}
}
int ask(int a,int b)
{
int A,B,ans=0;
for(A=top[a],B=top[b];A!=B;a=fa[A],A=top[a])
{
if(deep[A]<deep[B])swap(A,B),swap(a,b);
ans+=ASK(1,pos[A],pos[a]);
if(find(1,pos[A])==find(1,pos[fa[A]]))ans--;
}
if(deep[a]<deep[b])swap(a,b);
ans+=ASK(1,pos[b],pos[a]);
return ans;
}
void EDIT(int note,int l,int r,int x)
{
if(s[note].l==l&&r==s[note].r)
{
s[note].flag=x;
pushdown(note);
return ;
}
pushdown(note);
int mid=s[note].l+s[note].r>>1;
if(r<=mid)EDIT(note<<1,l,r,x),pushdown(note<<1|1);
else if(l>mid)EDIT(note<<1|1,l,r,x),pushdown(note<<1);
else EDIT(note<<1,l,mid,x),EDIT(note<<1|1,mid+1,r,x);
pushup(note);
}
void edit(int a,int b,int c)
{
int A,B;
for(A=top[a],B=top[b];A!=B;a=fa[A],A=top[a])
{
if(deep[A]<deep[B])swap(A,B),swap(a,b);
EDIT(1,pos[A],pos[a],c);
}
if(deep[a]<deep[b])swap(a,b);
EDIT(1,pos[b],pos[a],c);
}
int main()
{
// freopen("test.in","r",stdin);
// freopen("my.out","w",stdout);
int i,a,b,c;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)scanf("%d",&ttt[i]);
for(i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
cnt=0;
dfs1(root,0);
dfs2(root,0,root);
for(i=1;i<=n;i++)color[pos[i]]=ttt[i];
build(1,1,n);
for(i=1;i<=m;i++)
{
scanf("%s",opt);
if(opt[0]=='C')
{
scanf("%d%d%d",&a,&b,&c);
edit(a,b,c);
}
else
{
scanf("%d%d",&a,&b);
printf("%d\n",ask(a,b));
}
}
// fclose(stdin);
// fclose(stdout);
return 0;
}
数据:
#include <ctime>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
int main()
{
int i,j,k;
freopen("test.in","w",stdout);
srand((unsigned)time(NULL));
int n,m;
n=5;
m=10;
printf("%d %d\n",n,m);
for(i=1;i<=n;i++)printf("%d ",rand()%4+1);
puts("");
for(i=1;i<n;i++)
{
printf("%d %d\n",i+1,rand()%i+1);
}
for(i=1;i<=m;i++)
{
int opt=rand()%2;
if(opt==1)
{
int a=0,b=0;
while(a==b)a=rand()%n+1,b=rand()%n+1;
printf("C %d %d %d\n",a,b,rand()%4+1);
}
else
{
int a=0,b=0;
while(a==b)a=rand()%n+1,b=rand()%n+1;
printf("Q %d %d\n",a,b);
}
}
fclose(stdout);
return 0;
}
拍子:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
int main()
{
int i,j,k;
for(i=1;i<=3000;i++)
{
system("rand");
system("my");
system("std");
printf("Case %04d : ",i);
if(system("fc my.out std.out > NULL")==0)
{
puts("AC");
}
else
{
puts("WAWAWWAWAWWWA");
system("pause");
return 0;
}
}
system("pause");
return 0;
}