题目链接:https://cn.vjudge.net/problem/CodeForces-620E
题目大意
一棵树,1 z x表示把以z为根的子树节点(包括z)的颜色,更新成x,2 z表示询问以z为根的子树的颜色种类。
分析
首先要dfs序,把子树区间弄成连续的区间,放入线段树内。
因为一共60种颜色,我们可以用一个long long的二进制数来表示,颜色为1,就把第一位设成1,颜色为3,就把第3位设成3,查询时,把答案进行或运算,所以最后统计答案有多少个1就可以了。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int N = 4e5+5;
int n, m, cnt, tot;
int c[N], head[2*N], in[N], out[N], num[N];
struct Edge{
int to, nxt;
}edge[2*N];
struct node{
int l, r;
ll val, lazy;
}tr[N<<2];
void add(int a, int b)
{
edge[cnt].to = b;
edge[cnt].nxt = head[a];
head[a] = cnt++;
}
void pushup(int m)
{
tr[m].val = tr[m<<1].val | tr[m<<1|1].val;
}
void pushdown(int m)
{
if(tr[m].lazy)
{
tr[m<<1].val = tr[m].lazy;
tr[m<<1|1].val = tr[m].lazy;
tr[m<<1].lazy = tr[m].lazy;
tr[m<<1|1].lazy = tr[m].lazy;
tr[m].lazy = 0;
}
}
void build(int m, int l, int r)
{
tr[m].l = l;
tr[m].r = r;
tr[m].val = 0;
tr[m].lazy = 0;
if(l == r)
{
tr[m].val = (1LL << (num[l] - 1));
return ;
}
int mid = (l + r) >> 1;
build(m<<1, l, mid);
build(m<<1|1, mid + 1, r);
pushup(m);
}
void updata(int m, int l, int r, int w)
{
if(tr[m].l >= l && tr[m].r <= r)
{
tr[m].val = (1LL << (w - 1));
tr[m].lazy = (1LL << (w - 1));
return ;
}
pushdown(m);
int mid = (tr[m].l + tr[m].r) >> 1;
if(l <= mid) updata(m<<1, l, r, w);
if(r > mid) updata(m<<1|1, l, r, w);
pushup(m);
}
ll ask(int m, int l, int r)
{
if(tr[m].l >= l && tr[m].r <= r) return tr[m].val;
pushdown(m);
int mid = (tr[m].l + tr[m].r) >> 1;
ll res = 0;
if(l <= mid) res |= ask(m<<1, l, r);
if(r > mid) res |= ask(m<<1|1, l, r);
return res;
}
void dfs(int x, int pre)
{
in[x] = ++tot;
num[tot] = c[x];
for(int i = head[x]; i != -1; i = edge[i].nxt)
{
int j = edge[i].to;
if(j == pre) continue;
dfs(j, x);
}
out[x] = tot;
}
int main()
{
cnt = tot = 0;
memset(head, -1, sizeof head);
scanf("%d %d", &n, &m);
for(int i = 1; i <= n; i++)
scanf("%d", &c[i]);
for(int i = 1; i < n; i++)
{
int a, b;
scanf("%d %d", &a, &b);
add(a, b);
add(b, a);
}
dfs(1, 0);
build(1, 1, tot);
while(m--)
{
int op, z, w;
scanf("%d", &op);
if(op == 1)
{
scanf("%d %d", &z, &w);
updata(1, in[z], out[z], w);
}
else
{
scanf("%d", &z);
ll tmp = ask(1, in[z], out[z]);
int ans = 0;
while(tmp)
{
tmp -= (tmp & (-tmp));
ans++;
}
printf("%d\n", ans);
}
}
return 0;
}