Description
给定一棵以1为根的有根树,点可能是黑色或白色,操作如下。
- 选定一个点x,将x的子树中所有到x的距离为奇数的点的颜色反转。
- 选定一个点x,将点x的颜色反转。
- 选定一个点x,询问所有黑点y(包括点x)与点x的lca(最近公共祖先)的和。
Input
第一行两个正整数n,m,表示树的节点数和操作数。 第二行n个整数c[i],若c[i]=1说明点i是黑色,否则为白色。
接下来n-1行,每行两个正整数u,v表示一条边。 接下来m行,每行两个正整数op,x描述一个操作,op为操作种类,与题目描述所述相同。n,m<=200000
Output
对于每个询问输出答案。
Sample Input
3 5
1 1 1
1 2
1 3
3 2
2 1
3 2
1 1
3 2
Sample Output
4
3
0
题解
真是累啊…
写了一晚上
调了一下午+半个晚上
不过收获还是很大啊
比如第一次实现标记永久化
以及明示了自己的一个小错误,节点存储的答案不是打了标记之后的答案
考虑朴素做法
设x到根的链是 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn
设siz[x]表示x的子树中黑点个数
显然答案是 ∑ x i ∗ ( s i z [ x i ] − s i z [ x i + 1 ] ) \sum x_i*(siz[x_i]-siz[x_{i+1}]) ∑xi∗(siz[xi]−siz[xi+1])
发现贡献分布在一条链上,考虑树剖
这里就需要一点小技巧了
每个节点并不保存单点权值,而保存除去他的重儿子子树后,所有黑点*他的编号的和
即保存去除了重儿子子树后他对答案的贡献
记录 s u m [ 2 ] [ 2 ] [ n o w ] sum[2][2][now] sum[2][2][now]表示dep%2=i,col%2=j的状态下线段树上这一段的答案总和
记录 n u m [ 2 ] [ 2 ] [ n o w ] num[2][2][now] num[2][2][now]表示dep%2=i,col%2=j的状态下线段树上这一段的黑点总和
记录 i n v [ 2 ] [ n o w ] inv[2][now] inv[2][now]表示dep%2=i的状态下这一段颜色块的翻转情况
记录 t a g [ 2 ] [ n o w ] tag[2][now] tag[2][now]表示dep%2=i的状态下这一段答案的翻转情况
对于单点修改颜色
我们找到这个点,改变他的inv标记并记录在某个dep,某个col更改了的贡献
之后向上跳重链,每条重链的父亲的答案会被影响,暴力改掉
对于子树翻转颜色
我们找到这个点的子树这段区间,改变inv标记及tag标记,同样记录变化了多少的贡献
向上跳重链,每次暴力统计子树的贡献分别改变了多少,暴力改掉
对于单点查询
找到这个点,先暴力统计子树颜色数量并计入答案,再向上跳重链
对于除去重链的最底端这一段,可以直接计入答案
在重链的最底端,显然他由一条轻链与另一条重链连接
设最底端为x
暴力统计x的子树贡献并去除下面跳上来那个点的子树贡献
计入答案即可
8k+,感觉以后也不想写了呀…
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#include<map>
#define LL long long
#define mp(x,y) make_pair(x,y)
#define lc now<<1
#define rc now<<1|1
using namespace std;
inline int read()
{
int f=1,x=0;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void write(LL x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)write(x/10);
putchar(x%10+'0');
}
inline void pr1(int x){write(x);printf(" ");}
inline void pr2(LL x){write(x);puts("");}
struct pt
{
LL a1,b1,a2,b2;
pt(){a1=b1=a2=b2=0;}
pt(LL _a1,LL _b1,LL _a2,LL _b2){a1=_a1;b1=_b1;a2=_a2;b2=_b2;}
}w[410000];
struct node{int x,y,next;}a[410000];int len,last[210000];
void ins(int x,int y){len++;a[len].x=x;a[len].y=y;a[len].next=last[x];last[x]=len;}
int fa[210000],dep[210000],tot[210000],son[210000],n,m;
void pre_tree_node(int x)
{
tot[x]=1;son[x]=0;
for(int k=last[x];k;k=a[k].next)
{
int y=a[k].y;
if(y!=fa[x])
{
fa[y]=x;dep[y]=dep[x]+1;
pre_tree_node(y);
if(tot[y]>tot[son[x]])son[x]=y;
tot[x]+=tot[y];
}
}
}
int ys[210000],fac[210000],z,top[210000],down[210000];
void pre_tree_edge(int x,int tp)
{
ys[x]=++z;fac[z]=x;top[x]=tp;
if(son[x])pre_tree_edge(son[x],tp),down[x]=down[son[x]];
else down[x]=x;
for(int k=last[x];k;k=a[k].next)
if(a[k].y!=fa[x]&&a[k].y!=son[x])pre_tree_edge(a[k].y,a[k].y);
}
LL sum[2][2][210000*4];//dep%2=i col=j 答案
int col[210000],inv[2][210000*4];//这个点的颜色是否相反
int tag[2][210000*4];//dep%2=i 这段区间是否相反
LL num[2][2][210000*4];//dep%2=i col=j 节点个数
//0黑1白
void pushup(int now)
{
for(int i=0;i<=1;i++)//col
for(int j=0;j<=1;j++)//dep
{
int c1=i^tag[j][now];
sum[j][i][now]=sum[j][c1][lc]+sum[j][c1][rc];
num[j][i][now]=num[j][c1][lc]+num[j][c1][rc];
}
}
int c1;
void modify1(int now,int l,int r,int p)
{
c1^=inv[dep[fac[p]]%2][now];
if(l==r)
{
inv[dep[fac[p]]%2][now]^=1;c1^=1;
int c=col[fac[p]]^inv[dep[fac[p]]%2][now],x=fac[p],p1=dep[x]%2;
num[p1][c][now]++;num[p1][c^1][now]--;
sum[p1][c][now]=num[p1][c][now]*x;sum[p1][c^1][now]=num[p1][c^1][now]*x;
return ;
}
int mid=(l+r)/2;
if(p<=mid)modify1(lc,l,mid,p);
else modify1(rc,mid+1,r,p);
pushup(now);
}
void modify2(int now,int l,int r,int p,int de,int a1,int a2)//黑点+a1,白点+a2
{
if(l==r)
{
num[de][0][now]+=a1;num[de][1][now]+=a2;
sum[de][0][now]=num[de][0][now]*fac[p];
sum[de][1][now]=num[de][1][now]*fac[p];
return ;
}
if(tag[de][now])swap(a1,a2);
int mid=(l+r)/2;
if(p<=mid)modify2(lc,l,mid,p,de,a1,a2);
else modify2(rc,mid+1,r,p,de,a1,a2);
pushup(now);
}
void mulchange1(int x)
{
c1=col[x];modify1(1,1,z,ys[x]);
int u1=0,u2=0;
if(!c1)u1=1,u2=-1;else u2=1,u1=-1;
int dd=dep[x]%2;
while(top[x]!=1)
{
int u=fa[top[x]];
modify2(1,1,z,ys[u],dd,u1,u2);
x=u;
}
}
//----------------------------------op=2
struct ph
{
int a,b;
ph(){a=b=0;}
ph(int _a,int _b){a=_a;b=_b;}
};
void modify4(int now,int l,int r,int ql,int qr,int de)
{
if(l==ql&&r==qr)
{
tag[de][now]^=1;inv[de][now]^=1;
swap(sum[de][0][now],sum[de][1][now]);
swap(num[de][0][now],num[de][1][now]);
return ;
}
int mid=(l+r)/2;
if(qr<=mid)modify4(lc,l,mid,ql,qr,de);
else if(mid+1<=ql)modify4(rc,mid+1,r,ql,qr,de);
else modify4(lc,l,mid,ql,mid,de),modify4(rc,mid+1,r,mid+1,qr,de);
pushup(now);
}
void modify5(int now,int l,int r,int p,int a1,int b1,int a2,int b2)//=0 黑+=a1 =1 黑+=a2
{
if(l==r)
{
num[0][0][now]+=a1;num[0][1][now]+=b1;
sum[0][0][now]=num[0][0][now]*fac[p];
sum[0][1][now]=num[0][1][now]*fac[p];
num[1][0][now]+=a2;num[1][1][now]+=b2;
sum[1][0][now]=num[1][0][now]*fac[p];
sum[1][1][now]=num[1][1][now]*fac[p];
return ;
}
if(tag[0][now])swap(a1,b1);
if(tag[1][now])swap(a2,b2);
int mid=(l+r)/2;
if(p<=mid)modify5(lc,l,mid,p,a1,b1,a2,b2);
else modify5(rc,mid+1,r,p,a1,b1,a2,b2);
pushup(now);
}
pt query1(int now,int l,int r,int ql,int qr)
{
if(ql>qr)return pt(0,0,0,0);
if(l==ql&&r==qr)
{
int a1,b1,a2,b2;
a1=num[0][0][now],b1=num[0][1][now];
a2=num[1][0][now],b2=num[1][1][now];
return pt(a1,b1,a2,b2);
}
int mid=(l+r)/2;
if(qr<=mid)
{
pt ret=query1(lc,l,mid,ql,qr);
if(tag[0][now])swap(ret.a1,ret.b1);
if(tag[1][now])swap(ret.a2,ret.b2);
return ret;
}
else if(mid+1<=ql)
{
pt ret=query1(rc,mid+1,r,ql,qr);
if(tag[0][now])swap(ret.a1,ret.b1);
if(tag[1][now])swap(ret.a2,ret.b2);
return ret;
}
else
{
int a1,b1,a2,b2;
pt r1,r2;
r1=query1(lc,l,mid,ql,mid);r2=query1(rc,mid+1,r,mid+1,qr);
if(tag[0][now])a1=r1.b1+r2.b1,b1=r1.a1+r2.a1;
else a1=r1.a1+r2.a1,b1=r1.b1+r2.b1;
if(tag[1][now])a2=r1.b2+r2.b2,b2=r1.a2+r2.a2;
else a2=r1.a2+r2.a2,b2=r1.b2+r2.b2;
return pt(a1,b1,a2,b2);
}
}
void ad(pt &x,pt y)
{
x.a1+=y.a1;x.a2+=y.a2;
x.b1+=y.b1;x.b2+=y.b2;
}
pt getsum(int x)
{
pt tmp;
while(x!=0)
{
ad(tmp,query1(1,1,z,ys[x],ys[down[x]]));
x=son[down[x]];
}
return tmp;
}
void mulchange2(int x)
{
int de=(dep[x]%2+1)%2;
pt tmp=getsum(x);
modify4(1,1,z,ys[x],ys[x]+tot[x]-1,de);
pt g1=getsum(x);
int u1,u2;
if(de)u1=g1.a2-tmp.a2,u2=g1.b2-tmp.b2;
else u1=g1.a1-tmp.a1,u2=g1.b1-tmp.b1;
while(top[x]!=1)
{
int u=fa[top[x]];
if(de)modify5(1,1,z,ys[u],0,0,u1,u2);
else modify5(1,1,z,ys[u],u1,u2,0,0);
x=u;
}
}
//--------------------------op=1
pt query2(int now,int l,int r,int ql,int qr)
{
if(l==ql&&r==qr)
{
LL a1,b1,a2,b2;
a1=sum[0][0][now],b1=sum[0][1][now];
a2=sum[1][0][now],b2=sum[1][1][now];
return pt(a1,b1,a2,b2);
}
int mid=(l+r)/2;
if(qr<=mid)
{
pt ret=query2(lc,l,mid,ql,qr);
if(tag[0][now])swap(ret.a1,ret.b1);
if(tag[1][now])swap(ret.a2,ret.b2);
return ret;
}
else if(mid+1<=ql)
{
pt ret=query2(rc,mid+1,r,ql,qr);
if(tag[0][now])swap(ret.a1,ret.b1);
if(tag[1][now])swap(ret.a2,ret.b2);
return ret;
}
else
{
int a1,b1,a2,b2;
pt r1,r2;
r1=query2(lc,l,mid,ql,mid);r2=query2(rc,mid+1,r,mid+1,qr);
if(tag[0][now])a1=r1.b1+r2.b1,b1=r1.a1+r2.a1;
else a1=r1.a1+r2.a1,b1=r1.b1+r2.b1;
if(tag[1][now])a2=r1.b2+r2.b2,b2=r1.a2+r2.a2;
else a2=r1.a2+r2.a2,b2=r1.b2+r2.b2;
return pt(a1,b1,a2,b2);
}
}
pt solve(int x)
{
pt tmp,ret;int y;
ret=tmp=getsum(x);
ret.a1*=x;ret.a2*=x;
if(x!=top[x])
{
pt u1=query1(1,1,z,ys[top[x]],ys[fa[x]]),u2=query2(1,1,z,ys[top[x]],ys[fa[x]]);
ad(tmp,u1);ad(ret,u2);
}
x=fa[top[x]];
while(x!=0)
{
pt u1=getsum(x);
//ad(tmp,u1);
ret.a1+=(u1.a1-tmp.a1)*x;ret.a2+=(u1.a2-tmp.a2)*x;
tmp=u1;
if(x!=top[x])
{
pt u3=query1(1,1,z,ys[top[x]],ys[fa[x]]),u2=query2(1,1,z,ys[top[x]],ys[fa[x]]);
ad(tmp,u3);ad(ret,u2);
}
x=fa[top[x]];
}
return ret;
}
//-------------op=3
int f1[210000][2],g1[210000][2],f2[210000][2],g2[210000][2];
void treedp(int x)
{
f1[x][dep[x]%2]=g1[x][dep[x]%2]=(col[x]==0);
f2[x][dep[x]%2]=g2[x][dep[x]%2]=(col[x]==1);
for(int k=last[x];k;k=a[k].next)
{
int y=a[k].y;
if(y!=fa[x])
{
treedp(y);
g1[x][0]+=g1[y][0];g1[x][1]+=g1[y][1];
g2[x][0]+=g2[y][0];g2[x][1]+=g2[y][1];
if(y!=son[x])f1[x][0]+=g1[y][0],f1[x][1]+=g1[y][1],f2[x][0]+=g2[y][0],f2[x][1]+=g2[y][1];
}
}
}
void debug(int now,int l,int r,int p)
{
c1^=inv[dep[fac[p]]%2][now];
if(l==r)return ;
int mid=(l+r)/2;
if(p<=mid)debug(lc,l,mid,p);
else debug(rc,mid+1,r,p);
//pushup(now);
}
int main()
{
// freopen("a.in","r",stdin);
// freopen("a.out","w",stdout);
n=read();m=read();
for(int i=1;i<=n;i++)col[i]=read(),col[i]^=1;
for(int i=1;i<n;i++)
{
int x=read(),y=read();
ins(x,y);ins(y,x);
}
pre_tree_node(1);
pre_tree_edge(1,1);
treedp(1);
for(int i=1;i<=n;i++)
modify5(1,1,z,ys[i],f1[i][0],f2[i][0],f1[i][1],f2[i][1]);
while(m--)
{
int op=read(),x=read();
/*printf("types %d %d ",op,x);
printf("CHECKER :");
for(int i=1;i<=n;i++)
{
pt tmp=solve(i);
pr1(tmp.a1+tmp.a2);
}
printf(" ");
printf("COLOUR :");
for(int i=1;i<=n;i++)
{
c1=col[i];debug(1,1,z,ys[i]);
pr1(c1);
}
pt tmp=getsum(1);
puts("");
printf(" a1:%d b1:%d a2:%d b2:%d",tmp.a1,tmp.b1,tmp.a2,tmp.b2);
puts("");*/
if(op==1)
mulchange2(x);
else if(op==2)
mulchange1(x);
else
{
pt tmp=solve(x);
pr2(tmp.a1+tmp.a2);
}
}
return 0;
}