题目大意:给一颗树,根节点为1,每个节点都有颜色(<=60),支持将一个子树中的所有节点都染成颜色c的操作和询问以v为根节点的子树中有多少种不同的颜色。
对子树操作,dfs序+线段树 (对树上的链操作用树链剖分)
表示颜色(col<=60) 可以用位运算(bitmasks)
两个区间颜色数的合并为 or ,总颜色统计为二进制上出现的1的个数
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
using namespace std;
#define MAXN 400100
#define LL long long
int n,m,l,x,y,tot,co,t,v;
LL ans;
int ans1;
int st[MAXN],ed[MAXN],c[MAXN];
int s[2*MAXN];
bool vis[MAXN];
struct node{
int l,r,flag;
LL col;
}tree[5*MAXN];
struct point{
int y,next;
}edge[MAXN*2];
int head[MAXN];
void add(int x,int y)
{
l++;
edge[l].y=y;
edge[l].next=head[x];
head[x]=l;
}
void dfs(int x)
{
st[x]=++tot;
s[tot]=x;
vis[x]=1;
for (int i=head[x];i!=0;i=edge[i].next)
if(vis[edge[i].y]==0)
dfs(edge[i].y);
ed[x]=++tot;
s[tot]=x;
}
void pushup(int p)
{
tree[p].col=tree[p<<1].col | tree[p<<1|1].col;
}
void pushdown(int p)
{
if (tree[p].flag>0)
{
tree[p<<1].flag=tree[p<<1|1].flag=tree[p].flag;
tree[p<<1].col=tree[p<<1|1].col=tree[p].col;
tree[p].flag=0;
}
}
void build(int p,int l,int r)
{
tree[p].l=l;
tree[p].r=r;
tree[p].flag=0;
if (l==r) {
tree[p].col=(LL)1 << (LL)c[s[l]];
return ;
}
int mid=(l+r) >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
pushup(p);
}
void change(int p,int l,int r,int c)
{
if (tree[p].l==l && tree[p].r==r)
{
tree[p].col=(LL)1 << (LL)c;
tree[p].flag=c;
return ;
}
pushdown(p);
int mid=(tree[p].l+tree[p].r) >> 1;
if (r<=mid) change(p<<1, l, r, c);
if (l>mid) change(p<<1|1, l, r, c);
if (l<=mid && r>mid)
{
change(p<<1,l,mid,c);
change(p<<1|1,mid+1,r,c);
}
pushup(p);
}
LL query(int p,int l,int r)
{
if (tree[p].l==l && tree[p].r==r)
{
return tree[p].col;
}
pushdown(p);
int mid=(tree[p].l+tree[p].r) >> 1;
if (r<=mid) return query(p<<1,l,r);
if (l>mid) return query(p<<1|1,l,r);
if (l<=mid && r>mid)
{
LL s1=query(p<<1,l,mid);
LL s2=query(p<<1|1,mid+1,r);
return s1|s2;
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i=1;i<=n;i++) scanf("%d", &c[i]);
memset(head,0,sizeof(head));
memset(vis,0,sizeof(vis));
for (int i=1;i<n;i++)
{
scanf("%d%d", &x, &y);
add(x,y);
add(y,x);
}
dfs(1);
build(1,1,2*n);
for (int i=1;i<=m;i++)
{
scanf("%d%d", &t, &v);
if (t==1)
{
scanf("%d", &co);
change(1,st[v],ed[v],co);
}
else
{
ans=(LL)query(1,st[v],ed[v]);
ans1=0;
while (ans>0)
{
ans1+=ans & 1;
ans>>=1;
}
printf("%d\n", ans1);
}
}
}