2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 8465 Solved: 3170
[ Submit][ Status][ Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
巨难受
终于调A了
调了好久QAQ
还是师兄给力
好久没写4000K的代码了
线段树的部分其实不难,合并的时候如果左区间的右端点颜色和右区间的左端点颜色相同
那么ans--(ans=左边的ans+右边的ans)
难点在树链剖分部分
我们发现合并的顺序和方向对答案是有影响的
所以我们要把u边的ans和v边的ans的分开存
然后讨论左右问题
发现我们在线段树里面得到的区间都是左边id小的,右边id大的
而id是dfs序,所以id小的深度一定较小
v这边也是一样的
所以我们每次都把新得到的区间放在之前的答案的左边
而当他们在一条重链上时
我们需要将新查询出的答案的左右端点分别跟ans1,ans2比较
重点还是要理解我们每次查询得到的区间,都是左端点高,右端点低
剩下的就可以推出来了(感谢师兄QAQ)
#include<cstdio>
#include<cstring>
const int N=1e5+7;
struct node
{
int l,r;
int sum,ll,rr;
int set;
void clear()
{
sum=0,ll=0,rr=0;
}
}e[N*4];
struct edgt
{
int to,next;
}s[N*2];
int cnt,first[N];
int size[N],id[N],dep[N],fa[N];
inline void insert(int u,int v)
{
s[++cnt]=(edgt){v,first[u]};first[u]=cnt;
s[++cnt]=(edgt){u,first[v]};first[v]=cnt;
}
#define lson ro<<1
#define rson ro<<1|1
inline void pushup(node &ro,node l,node r)
{
ro.ll=l.ll;
ro.rr=r.rr;
ro.sum=l.sum+r.sum;
if(l.rr==r.ll) ro.sum--;
}
void build(int ro,int l,int r)
{
e[ro].l=l;e[ro].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(lson,l,mid);build(rson,mid+1,r);
}
#define lazy e[ro].set
inline void pushdown(int ro)
{
if(lazy)
{
e[lson].ll=e[lson].rr=lazy;
e[rson].ll=e[rson].rr=lazy;
e[lson].sum=e[rson].sum=1;
e[lson].set=lazy;
e[rson].set=lazy;
lazy=0;
}
}
node query(int ro,int l,int r)
{
if(l<=e[ro].l&&e[ro].r<=r) return e[ro];
pushdown(ro);
int mid=(e[ro].l+e[ro].r)>>1;
if(r<=mid) return query(lson,l,r);
else if(l>mid) return query(rson,l,r);
else
{
node a=query(lson,l,mid);node b=query(rson,mid+1,r);
node ans;ans.clear();
pushup(ans,a,b);
return ans;
}
}
void modify(int ro,int l,int r,int k)
{
if(l<=e[ro].l&&e[ro].r<=r)
{
e[ro].ll=e[ro].rr=k;lazy=k;e[ro].sum=1;
return;
}
pushdown(ro);
int mid=(e[ro].l+e[ro].r)>>1;
if(l<=mid) modify(lson,l,r,k);
if(r>mid) modify(rson,l,r,k);
pushup(e[ro],e[lson],e[rson]);
}
void dfs1(int x)
{
size[x]=1;
for(int k=first[x];k;k=s[k].next)
{
if(s[k].to==fa[x]) continue;
fa[s[k].to]=x;
dep[s[k].to]=dep[x]+1;
dfs1(s[k].to);
size[x]+=size[s[k].to];
}
}
int sz;int top[N];
void dfs2(int x,int chain)
{
int i=0;
id[x]=++sz;top[x]=chain;
for(int k=first[x];k;k=s[k].next)
if(dep[s[k].to]>dep[x]&&size[s[k].to]>size[i]) i=s[k].to;
if(!i) return;
dfs2(i,chain);
for(int k=first[x];k;k=s[k].next)
if(dep[s[k].to]>dep[x]&&s[k].to!=i) dfs2(s[k].to,s[k].to);
}
inline void swap(int &a,int &b)
{
int tmp=a;a=b;b=tmp;
}
int qsum(int u,int v)
{
node ans1,ans2,a;ans1.clear(),ans2.clear();//lca左边和右边的答案
bool tag=1;bool flag1=0,flag2=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v),tag^=1;;
a=query(1,id[top[u]],id[u]);
if(tag)
{
if(flag1)pushup(ans1,a,ans1);
else ans1.sum=a.sum,ans1.ll=a.ll,ans1.rr=a.rr,flag1=1;
}
else
{
if(flag2)pushup(ans2,a,ans2);
else ans2.sum=a.sum,ans2.ll=a.ll,ans2.rr=a.rr,flag2=1;
}
u=fa[top[u]];
}
if(id[u]>id[v]) swap(u,v),tag^=1;
a=query(1,id[u],id[v]);
int ans=a.sum;
if(tag){
if(flag1){
ans+=ans1.sum;
if(ans1.ll==a.ll) ans--;
}
if(flag2){
ans+=ans2.sum;
if(ans2.ll==a.rr) ans--;
}
}
else{
if(flag1){
ans+=ans1.sum;
if(ans1.ll==a.rr) ans--;
}
if(flag2){
ans+=ans2.sum;
if(ans2.ll==a.ll) ans--;
}
}
return ans;
}
void qset(int u,int v,int k)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
modify(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(id[u]>id[v]) swap(u,v);
modify(1,id[u],id[v],k);
}
int vi[N];
void change(int ro,int x)
{
if(e[ro].l==e[ro].r)
{
e[ro].sum=1;e[ro].ll=e[ro].rr=vi[x];return;
}
int mid=(e[ro].l+e[ro].r)>>1;
if(id[x]<=mid) change(lson,x);
else change(rson,x);
pushup(e[ro],e[lson],e[rson]);
}
int main()
{
int n,m,u,v;
scanf("%d %d",&n,&m);
build(1,1,n);
for(int i=1;i<=n;i++) scanf("%d",&vi[i]);
for(int i=1;i<n;i++) scanf("%d %d",&u,&v),insert(u,v);
dfs1(1);dfs2(1,1);
for(int i=1;i<=n;i++) change(1,i);
char a[2];
int q;
for(int i=1;i<=m;i++)
{
scanf("%s %d %d",a,&u,&v);
if(a[0]=='C') scanf("%d",&q),qset(u,v,q);
else printf("%d\n",qsum(u,v));
}
return 0;
}