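CF620E New Year Tree
题解：建议先掌握 POJ2777（线段树区间染色，用整数的二进制位表示颜色集合，合并时按位或）和 POJ3321（用 DFS 序把一棵子树映射成一段连续区间）这两道题，掌握之后本题就很简单了：子树染色 = 区间赋值，子树颜色数 = 区间按位或结果中 1 的个数。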
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
constexpr int N = 400000 + 10;  // capacity for vertices (with a little slack)
int n, m;                       // vertex count, query count
int tot, cnt;                   // tot: edge counter in Init, then DFS timestamp; cnt: not referenced in the code shown
// Forward-star adjacency storage; every undirected edge is stored twice, hence 2*N slots.
int first[N];
int ne[2*N], to[2*N];
// Per-vertex / per-position arrays.
int col[N];   // color of each vertex as read from input
int vis[N];   // not referenced in the code shown
int sz[N];    // subtree size of each vertex
int id[N];    // DFS-order position of each vertex
int ncol[N];  // color stored at each DFS-order position
// Segment-tree node covering the DFS-order positions [l, r].
struct Node
{
    int l, r;  // inclusive range of DFS positions covered by this node
    ll col;    // OR of single-color masks (1ll << color) over [l, r]
    ll lazy;   // pending "paint the whole range" mask for the children; 0 means none
    // Paint the entire range with mask `val` and keep it as the pending tag.
    void updata(ll val){
        lazy = val;
        col = lazy;
    }
}node[N<<2];
void add(int x,int y){
ne[++tot] = first[x];
to[tot] = y;
first[x] = tot;
}
// Read n, m, the n vertex colors and the n-1 tree edges from stdin.
// Each edge is inserted in both directions; tot ends up as the number of stored edges.
void Init(){
    scanf("%d%d",&n,&m);
    for(int v = 1; v <= n; ++v)
        scanf("%d",&col[v]);
    tot = 0;
    memset(first, 0, sizeof first);   // 0 marks an empty adjacency list
    for(int k = 1; k < n; ++k){
        int a, b;
        scanf("%d%d",&a,&b);
        add(a, b);
        add(b, a);
    }
}
void DFS(int u,int fa){
ncol[id[u]=++tot] = col[u];
sz[u] = 1;
for(int i=first[u];i;i=ne[i]){
int v = to[i];
if(v == fa) continue;
DFS(v,u);
sz[u] += sz[v];
}
}
// Recompute node id's color set as the union (bitwise OR) of its two children.
void push_up(int id){
    const int lc = id << 1;
    const int rc = lc | 1;
    node[id].col = node[lc].col | node[rc].col;
}
void push_down(int id){
ll val = node[id].lazy;
if(val){
node[id<<1].updata(val);
node[id<<1|1].updata(val);
node[id].lazy = 0;
}
}
void build(int id,int l,int r){
node[id].l = l, node[id].r = r;
node[id].lazy = 0;
if(l == r) node[id].col = 1ll<<ncol[r];
else{
int mid = (l + r) >> 1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
push_up(id);
}
}
// Range assignment: paint every position in [L, R] with color mask k (within node id's range).
void updata(int id,int L,int R,ll k){
    const int lo = node[id].l;
    const int hi = node[id].r;
    if(L <= lo && hi <= R){     // node fully covered: tag it and stop
        node[id].updata(k);
        return;
    }
    push_down(id);
    const int mid = (lo + hi) >> 1;
    if(L <= mid) updata(id<<1, L, R, k);
    if(R > mid) updata(id<<1|1, L, R, k);
    push_up(id);
}
// Return the OR of color masks over positions [L, R] (within node id's range).
ll query(int id,int L,int R){
    const int lo = node[id].l;
    const int hi = node[id].r;
    if(L <= lo && hi <= R) return node[id].col;
    push_down(id);
    const int mid = (lo + hi) >> 1;
    ll mask = 0;   // must start empty: either child may be skipped
    if(L <= mid) mask |= query(id<<1, L, R);
    if(R > mid) mask |= query(id<<1|1, L, R);
    return mask;
}
// Return the number of set bits in the 64-bit pattern of x (population count).
// The parameter is `long long`, the same type as `ll`.
// Bits are counted on the unsigned representation. The previous `x -= x & -x` loop
// negated a signed value, which is undefined behavior once x reaches LLONG_MIN
// (any mask with bit 63 set eventually does).
int count(long long x){
    unsigned long long bits = static_cast<unsigned long long>(x);
    int cnt = 0;
    while(bits){
        bits &= bits - 1;   // clear the lowest set bit
        ++cnt;
    }
    return cnt;
}
void solve(){
int t,x,y;
for(int i=1;i<=m;i++){
scanf("%d",&t);
if(t == 1){
scanf("%d%d",&x,&y);
updata(1,id[x],id[x]+sz[x]-1,1ll<<y);
}else{
scanf("%d",&x);
printf("%d\n",count(query(1,id[x],id[x]+sz[x]-1)));
}
}
}
// Read the tree, lay it out in DFS order, build the segment tree, then answer queries.
int main(){
    Init();            // uses tot as the edge counter
    tot = 0;           // tot is reused by DFS as the DFS-order timestamp
    DFS(1, 0);         // root at vertex 1; 0 means "no parent"
    build(1, 1, n);    // segment tree over DFS positions 1..n
    solve();
    return 0;
}