题意:给出一棵树,每个节点初始都有一个颜色值,有m次操作,1操作:输入id,c。第id个节点以及其子树都被染成c。
2操作:id,输出id子树中不同颜色的种数。
解法:先跑一次DFS,第id个节点就控制DFS序为l,r的区间。那么就转换为了线段树的区间操作。
怎么记录有多少种颜色?题意上说颜色种数不超过60。那么第c种颜色就对应二进制的第c位,两段区间合后的颜色为他们两段区间颜色值的“或”。最后看这个颜色值二进制中有多少个1。
#include<iostream>
#include<string>
#include<stdio.h>
#include<string.h>
#include<vector>
#include<math.h>
#include<queue>
#include<map>
#include<set>
#include<algorithm>
using namespace std;
#define MAXN 4*100005
#define LL long long
#define INF 0x3f7f7f7f
const double eps = 1e-10;
struct node
{
int l,r,c;
LL val;
}tr[MAXN<<2];
struct point
{
int l,r,c;
}point[MAXN];
struct Edge
{
int en,next;
}E[MAXN*2];
int p[MAXN],num;
int n,m,step;
int vis[MAXN],dfn[MAXN];
//dfn[i],DFS序为i对应的节点编号
void add(int st,int en)
{
E[num].en=en;
E[num].next=p[st];
p[st]=num++;
}
void dfs(int id)
{
point[id].l=++step;
dfn[step]=id;
vis[id]=1;
int i;
for(i=p[id];i!=-1;i=E[i].next)
{
int en=E[i].en;
if(!vis[en])
dfs(en);
}
point[id].r=step;
//cout<<"id:"<<id<<" "<<point[id].l<<" "<<point[id].r<<endl;
}
void pushup(int id)
{
tr[id].val=tr[id*2].val|tr[id*2+1].val;//两段区间合并后的颜色值
}
void pushdown(int id)
{
tr[id].c=0;
tr[id*2].c=tr[id*2+1].c=1;
tr[id*2].val=tr[id*2+1].val=tr[id].val;
}
void build(int id,int l,int r)
{
tr[id].l=l;
tr[id].r=r;
tr[id].c=0;
if(l==r)
{
tr[id].val=(LL)((LL)1<<(LL)point[dfn[l]].c);
}
else
{
int mid=(l+r)/2;
build(id*2,l,mid);
build(id*2+1,mid+1,r);
pushup(id);
}
}
void update(int id,int l,int r,LL val)
{
if(l<=tr[id].l&&tr[id].r<=r)
{
tr[id].val=(LL)((LL)1<<val);
tr[id].c=1;
}
else
{
int mid=(tr[id].l+tr[id].r)>>1;
if(tr[id].c)
pushdown(id);
if(r<=mid)update(id*2,l,r,val);
else if(l>mid)update(id*2+1,l,r,val);
else
{
update(id*2,l,r,val);
update(id*2+1,l,r,val);
}
pushup(id);
}
}
LL query(int id,int l,int r)
{
if(l<=tr[id].l&&tr[id].r<=r)
{
return tr[id].val;
}
else
{
int mid=(tr[id].l+tr[id].r)>>1;
if(tr[id].c)
pushdown(id);
LL ans;
if(r<=mid)return query(id*2,l,r);
else if(l>mid)return query(id*2+1,l,r);
else
{
ans=query(id*2,l,r);
ans=ans|query(id*2+1,l,r);
return ans;
}
}
}
int fun(LL x)
{
int ans=0;
for(LL i=0;i<=64;i++)
{
LL temp=(LL)((LL)1<<(LL)i);
if(temp&x)
ans++;
}
return ans;
}
int main()
{
int i;
while(scanf("%d%d",&n,&m)!=EOF)
{
num=0;
memset(p,-1,sizeof(p));
for(i=1;i<=n;i++)
scanf("%d",&point[i].c);
for(i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
step=0;
memset(vis,0,sizeof(vis));
dfs(1);
build(1,1,n);
for(i=1;i<=m;i++)
{
int opt;
int id,co;
scanf("%d",&opt);
if(opt==1)
{
scanf("%d%d",&id,&co);
update(1,point[id].l,point[id].r,co);
}
else
{
scanf("%d",&id);
LL x=query(1,point[id].l,point[id].r);
//cout<<i<<" ::: "<<x<<endl;
printf("%d\n",fun(x));
}
}
}
return 0;
}