Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
题解
树链剖分,线段树要打lazy-tag。一定要细心,注意函数及时退出。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#define N 100002
using namespace std;
int n,m,zz,head[N],v[N];
struct bian {int to,nx;} e[N*2];
int h[N],fa[N][17],son[N],vis[N];
int size,bl[N],tw[N];
struct shu {int l,r,s,lc,rc,tag;} tr[4*N];
void insert(int x,int y)
{
zz++; e[zz].to=y; e[zz].nx=head[x]; head[x]=zz;
zz++; e[zz].to=x; e[zz].nx=head[y]; head[y]=zz;
}
void init()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&v[i]);
for(int i=1;i<n;i++)
{int x,y; scanf("%d%d",&x,&y); insert(x,y);}
}
void dfs1(int x) //2^17=131072
{
vis[x]=son[x]=1;
for(int i=1;i<=16;i++)
{if(h[x]<(1<<i)) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for(int i=head[x];i;i=e[i].nx)
{if(vis[e[i].to]) continue;
h[e[i].to]=h[x]+1; fa[e[i].to][0]=x;
dfs1(e[i].to);
son[x]+=son[e[i].to];
}
}
void dfs2(int x,int l)
{
size++;
bl[x]=l; tw[x]=size;
int k=0;
for(int i=head[x];i;i=e[i].nx)
{if(h[e[i].to]>h[x]&&son[e[i].to]>son[k]) k=e[i].to;}
if(k==0) return;
dfs2(k,l);
for(int i=head[x];i;i=e[i].nx)
{if(h[e[i].to]>h[x]&&e[i].to!=k) dfs2(e[i].to,e[i].to);}
}
int lca(int x,int y)
{
if(h[x]<h[y]) swap(x,y);
int t=h[x]-h[y];
for(int i=0;i<=16;i++)
{if(t&(1<<i)) x=fa[x][i];}
for(int i=16;i>=0;i--)
{if(fa[x][i]!=fa[y][i])
{x=fa[x][i]; y=fa[y][i];}
}
if(x==y) return x;
return fa[x][0];
}
void build(int w,int l,int r)
{
tr[w].l=l; tr[w].r=r; tr[w].tag=-1;
if(l==r) return;
int mid=(l+r)>>1;
build(w<<1,l,mid); build((w<<1)+1,mid+1,r);
}
void down(int w)
{
int tg=tr[w].tag; tr[w].tag=-1;
if(tg==-1||tr[w].l==tr[w].r) return;
tr[w<<1].tag=tr[(w<<1)+1].tag=tg;
tr[w<<1].s=tr[(w<<1)+1].s=1;
tr[w<<1].lc=tr[(w<<1)+1].lc=tg;
tr[w<<1].rc=tr[(w<<1)+1].rc=tg;
}
void up(int w)
{
int j=1;
if(tr[w<<1].rc!=tr[(w<<1)+1].lc) j=0;
tr[w].s=tr[w<<1].s+tr[(w<<1)+1].s-j;
tr[w].lc=tr[w<<1].lc; tr[w].rc=tr[(w<<1)+1].rc;
}
void change(int w,int x,int y,int c)
{
down(w);
int l=tr[w].l,r=tr[w].r;
if(l==x&&r==y) {tr[w].lc=tr[w].rc=c; tr[w].s=1; tr[w].tag=c; return ;}
int mid=(l+r)>>1;
if(mid>=y) change(w<<1,x,y,c);
else if(mid<x) change((w<<1)+1,x,y,c);
else {change(w<<1,x,mid,c); change((w<<1)+1,mid+1,y,c);}
up(w);
}
int find(int w,int x,int y)
{
down(w);
int l=tr[w].l,r=tr[w].r;
if(x==l&&y==r) return tr[w].s;
int mid=(l+r)>>1;
if(mid>=y) return find(w<<1,x,y);
else if(mid<x) return find((w<<1)+1,x,y);
else
{int j=1;
if(tr[w<<1].rc!=tr[(w<<1)+1].lc) j=0;
return find(w<<1,x,mid)+find((w<<1)+1,mid+1,y)-j;
}
}
int getc(int w,int x)
{
down(w);
int l=tr[w].l,r=tr[w].r;
if(l==r) return tr[w].lc;
int mid=(l+r)>>1;
if(mid>=x) return getc(w<<1,x);
else return getc((w<<1)+1,x);
}
int ask(int x,int y)
{
int sum=0;
while(bl[x]!=bl[y])
{sum+=find(1,tw[bl[x]],tw[x]);
if(getc(1,tw[bl[x]])==getc(1,tw[fa[bl[x]][0]])) sum--;
x=fa[bl[x]][0];
}
sum+=find(1,tw[y],tw[x]);
return sum;
}
void turn(int x,int y,int c)
{
while(bl[x]!=bl[y])
{change(1,tw[bl[x]],tw[x],c);
x=fa[bl[x]][0];
}
change(1,tw[y],tw[x],c);
}
void work()
{
char ch[5];
build(1,1,n);
for(int i=1;i<=n;i++) change(1,tw[i],tw[i],v[i]);
for(int i=1;i<=m;i++)
{scanf("%s",ch);
int x,y,z,t;
if(ch[0]=='Q')
{scanf("%d%d",&x,&y);
t=lca(x,y);
printf("%d\n",ask(x,t)+ask(y,t)-1);
}
else
{scanf("%d%d%d",&x,&y,&z);
t=lca(x,y);
turn(x,t,z); turn(y,t,z);
}
}
}
int main()
{
init(); dfs1(1); dfs2(1,1); work();
return 0;
}