2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 6661 Solved: 2439
[ Submit][ Status][ Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
学长讲的树链剖分例题,最近在学lct,发现这题lct可以做的样子,就先写的lct,不过其实树链剖分更简单些。
思路:记录链左端和右端的颜色,pushup及合并时,若左子树的右端点颜色与右子树的左端点相同,则num=numl+numr-1否则num=numl+numr。
两份代码附上:
1.lct代码:
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<string>
#include<climits>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define N 100010
using namespace std;
typedef long long ll;
ll read()
{
ll x=0,f=1;char ch;
while(ch<'0'||ch>'9'){
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
ll n,m,c[N],v[N];
ll tr[N][2],fa[N];
ll num[N],ln[N],rn[N],tag[N];
bool rev[N];ll st[N*10];
bool isroot(ll x)
{return tr[fa[x]][0]!=x&&tr[fa[x]][1]!=x;}
void pushup(ll x)
{
ll l=tr[x][0],r=tr[x][1];num[x]=1;
if(l)ln[x]=ln[l],num[x]+=num[l];
else ln[x]=v[x];
if(r)rn[x]=rn[r],num[x]+=num[r];
else rn[x]=v[x];
if(l&&rn[l]==v[x])num[x]--;
if(r&&ln[r]==v[x])num[x]--;
}
void pushdown(ll x)
{
ll l=tr[x][0],r=tr[x][1];
if(rev[x])
{
rev[x]^=1,rev[l]^=1,rev[r]^=1;
swap(tr[l][0],tr[l][1]);swap(ln[l],rn[l]);
swap(tr[r][0],tr[r][1]);swap(ln[r],rn[r]);
}
if(tag[x]!=-1)
{
if(l)ln[l]=rn[l]=tag[x],num[l]=1,tag[l]=tag[x],v[l]=tag[x];
if(r)ln[r]=rn[r]=tag[x],num[r]=1,tag[r]=tag[x],v[r]=tag[x];
tag[x]=-1;
}
}
void rotate(ll x)
{
ll y=fa[x],z=fa[y],l,r;
l=tr[y][1]==x;r=l^1;
if(!isroot(y))tr[z][tr[z][1]==y]=x;
fa[x]=z,fa[y]=x,fa[tr[x][r]]=y;
tr[y][l]=tr[x][r],tr[x][r]=y;
pushup(y);pushup(x);
}
void splay(ll x)
{
ll top=0;st[++top]=x;
for(ll i=x;!isroot(i);i=fa[i])
st[++top]=fa[i];
while(top)pushdown(st[top--]);
while(!isroot(x))
{
ll y=fa[x],z=fa[y];
if(!isroot(y))
{
if(tr[y][0]==x^tr[z][0]==y)
rotate(x);
else rotate(y);
}rotate(x);
}
}
void access(ll x)
{
ll t=0;
while(x)
{
splay(x);
tr[x][1]=t;
pushup(x);
t=x,x=fa[x];
}
}
void rever(ll x)
{
access(x);
splay(x);
rev[x]^=1;
swap(tr[x][0],tr[x][1]);
swap(ln[x],rn[x]);
}
void link(ll x,ll y)
{
rever(x);
fa[x]=y;
splay(x);
}
void cut(ll x,ll y){
rever(x);
access(y);
splay(y);
tr[y][0]=fa[x]=0;
pushup(y);
}
void split(ll x,ll y){
rever(x);
access(y);
splay(y);
}
int main()
{
// freopen("in.txt","r",stdin);
// freopen("my.txt","w",stdout);
n=read(),m=read();
for(ll i=1;i<=n;i++)
{
v[i]=c[i]=read();
ln[i]=rn[i]=c[i];
num[i]=1;tag[i]=-1;
}
for(ll i=1;i<n;i++)
{
ll x=read(),y=read();
link(x,y);
}char ch[10];
for(ll i=1;i<=m;i++)
{
scanf("%s",ch);
ll x=read(),y=read();
if(ch[0]=='C')
{
ll z=read();
split(x,y);tag[y]=v[y]=z;
ln[y]=rn[y]=z;num[y]=1;
}
else
{
split(x,y);
printf("%lld\n",num[y]);
}
}
}
2.树链剖分:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<string>
#include<climits>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define N 100010
#define M 1001000
using namespace std;
int read()
{
int x=0,f=1;char ch;
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,m,c[N],col[N];
int head[N],pos;
struct edge{int to,next;}e[N<<1];
int tag[N],bl[N];
struct node{int l,r,num,lc,rc,tag;}t[M];
void add(int a,int b)
{pos++;e[pos].to=b,e[pos].next=head[a],head[a]=pos;}
int dep[N],p[N],size[N];
int tme,id[N];
void dfs1(int u)
{
size[u]=1,dep[u]=dep[p[u]]+1;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==p[u])continue;
p[v]=u;dfs1(v);
size[u]+=size[v];
}
}
void dfs2(int u,int tp)
{
id[u]=++tme;bl[u]=tp;int k=0;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==p[u])continue;
if(size[v]>size[k])k=v;
}if(!k)return ;dfs2(k,tp);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==p[u]||v==k)continue;
dfs2(v,v);
}
}
void pushup(int k)
{
if(t[k].l==t[k].r)return;
t[k].num=t[k<<1].num+t[k<<1|1].num;
t[k].lc=t[k<<1].lc,t[k].rc=t[k<<1|1].rc;
if(t[k<<1].rc==t[k<<1|1].lc)t[k].num--;
}
void pushdown(int k)
{
int l=t[k].l,r=t[k].r;
if(l==r)return;
l=k<<1,r=k<<1|1;
if(t[k].tag!=-1)
{
t[l].num=t[r].num=1;
t[l].lc=t[l].rc=t[k].tag;
t[r].lc=t[r].rc=t[k].tag;
t[l].tag=t[r].tag=t[k].tag;
t[k].tag=-1;
}
}
void build(int u,int l,int r)
{
t[u].l=l,t[u].r=r,t[u].tag=-1;
if(l==r){t[u].num=1,t[u].lc=t[u].rc=col[l];return;}
int mid=(l+r)>>1;build(u<<1,l,mid);build(u<<1|1,mid+1,r);
pushup(u);
}
void change(int u,int x,int y,int val)
{
int l=t[u].l,r=t[u].r;pushdown(u);
if(x<=l&&y>=r)
{
t[u].tag=val;t[u].num=1;
t[u].lc=t[u].rc=val;
return;
}int mid=(l+r)>>1;
if(y<=mid)change(u<<1,x,y,val);
else if(x>mid)change(u<<1|1,x,y,val);
else {change(u<<1,x,mid,val);change(u<<1|1,mid+1,y,val);}
pushup(u);
}
node ask(int u,int x,int y)
{
int l=t[u].l,r=t[u].r;pushdown(u);
if(x<=l&&y>=r)return t[u];int mid=(l+r)>>1;
if(y<=mid)return ask(u<<1,x,y);
else if(x>mid)return ask(u<<1|1,x,y);
else
{
node tmp=ask(u<<1,x,mid);
node txp=ask(u<<1|1,mid+1,y);
node ret;ret.num=tmp.num+txp.num;
if(tmp.rc==txp.lc)ret.num--;
ret.l=tmp.l,ret.r=txp.r;
ret.lc=tmp.lc,ret.rc=txp.rc;
return ret;
}
}
int getc(int u,int x)
{
int l=t[u].l,r=t[u].r;
pushdown(u);int mid=(l+r)>>1;
if(l==r)return t[u].lc;
if(x<=mid)return getc(u<<1,x);
else return getc(u<<1|1,x);
}
int solve_num(int x,int y)
{
int ret=0,last=-1;
while(bl[x]!=bl[y])
{
if(dep[bl[x]]<dep[bl[y]])
swap(x,y);
node tmp=ask(1,id[bl[x]],id[x]);
ret+=tmp.num;
if(getc(1,id[bl[x]])==getc(1,id[p[bl[x]]]))ret--;
x=p[bl[x]];
}if(id[x]>id[y])swap(x,y);
node tmp=ask(1,id[x],id[y]);
ret+=tmp.num;return ret;
}
void solve_ch(int x,int y,int cl)
{
while(bl[x]!=bl[y])
{
if(dep[bl[x]]<dep[bl[y]])
swap(x,y);
change(1,id[bl[x]],id[x],cl);
x=p[bl[x]];
}if(id[x]>id[y])swap(x,y);
change(1,id[x],id[y],cl);
}
void solve()
{
char ch[10];
for(int i=1;i<=m;i++)
{
scanf("%s",ch);
int x=read(),y=read();
if(ch[0]=='C')
solve_ch(x,y,read());
else printf("%d\n",solve_num(x,y));
}
}
int main()
{
// freopen("in.txt","r",stdin);
// freopen("my.txt","w",stdout);
n=read();m=read();
for(int i=1;i<=n;i++)
c[i]=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y);add(y,x);
}dfs1(1);dfs2(1,1);
for(int i=1;i<=n;i++)
col[id[i]]=c[i];
build(1,1,n);
solve();
}