看了题解才懂系列..
利用dfs序将树形结构化为线性结构 先序遍历一棵树得到的序列中每个父节点之后紧跟的都是其子节点 只要知道每个节点有多少子节点 就知道了需要修改的区间的范围
剩下的就是状态压缩了 60种颜色用longlong存即可
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct nodeI
{
int v;
int next;
};
struct nodeII
{
int l;
int r;
ll val;
ll laz;
};
nodeI edge[800010];
nodeII tree[2000010];
ll clr[400010],pre[100];
int first[400010],id[400010],mp[400010],sum[400010];
int n,q,num;
void addedge(int u,int v)
{
edge[num].v=v;
edge[num].next=first[u];
first[u]=num++;
return;
}
void dfs(int fa,int u)
{
int i,v;
id[u]=++num;
for(i=first[u];i!=-1;i=edge[i].next)
{
v=edge[i].v;
if(v!=fa)
{
dfs(u,v);
sum[u]+=(sum[v]+1);
}
}
return;
}
void pushup(int cur)
{
tree[cur].val=tree[cur*2].val|tree[cur*2+1].val;
return;
}
void pushdown(int cur)
{
if(tree[cur].laz)
{
tree[cur*2].val=tree[cur].laz;
tree[cur*2].laz=tree[cur].laz;
tree[cur*2+1].val=tree[cur].laz;
tree[cur*2+1].laz=tree[cur].laz;
tree[cur].laz=0;
}
return;
}
void build(int l,int r,int cur)
{
int m;
tree[cur].l=l;
tree[cur].r=r;
tree[cur].val=0;
tree[cur].laz=0;
if(l==r)
{
tree[cur].val=clr[mp[++num]];
return;
}
m=(l+r)/2;
build(l,m,cur*2);
build(m+1,r,cur*2+1);
pushup(cur);
return;
}
void update(int pl,int pr,ll val,int cur)
{
if(pl<=tree[cur].l&&tree[cur].r<=pr)
{
tree[cur].val=val;
tree[cur].laz=val;
return;
}
pushdown(cur);
if(pl<=tree[cur*2].r) update(pl,pr,val,cur*2);
if(pr>=tree[cur*2+1].l) update(pl,pr,val,cur*2+1);
pushup(cur);
return;
}
ll query(int pl,int pr,int cur)
{
ll res;
if(pl<=tree[cur].l&&tree[cur].r<=pr)
{
return tree[cur].val;
}
pushdown(cur);
res=0;
if(pl<=tree[cur*2].r) res|=query(pl,pr,cur*2);
if(pr>=tree[cur*2+1].l) res|=query(pl,pr,cur*2+1);
pushup(cur);
return res;
}
int judge(ll p)
{
int res;
res=0;
while(p>0)
{
if(p%2==1) res++;
p/=2;
}
return res;
}
void init()
{
int i;
pre[0]=1;
for(i=1;i<=60;i++)
{
pre[i]=pre[i-1]*2;
}
return;
}
int main()
{
ll tem;
int i,op,u,v,c;
init();
while(scanf("%d%d",&n,&q)!=EOF)
{
for(i=1;i<=n;i++)
{
scanf("%d",&c);
clr[i]=pre[c];
}
memset(first,-1,sizeof(first));
num=0;
for(i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
memset(sum,0,sizeof(sum));
num=0;
dfs(-1,1);
for(i=1;i<=n;i++)
{
mp[id[i]]=i;
}
num=0;
build(1,n,1);
while(q--)
{
scanf("%d",&op);
if(op==1)
{
scanf("%d%d",&u,&c);
update(id[u],id[u]+sum[u],pre[c],1);
}
else
{
scanf("%d",&u);
tem=query(id[u],id[u]+sum[u],1);
printf("%d\n",judge(tem));
}
}
}
return 0;
}