题目大意
给你一棵树,要求支持两种操作:
-
x c
,给 x x x 的子树染上颜色 c c c(一个节点可以有多个颜色)。 -
x
,求 x x x 的子树的所有节点的所有颜色个数和(每个节点单独算)。
解题思路
首先我们还是对树求出 dfs
序,接着我们用线段树维护每个节点的丰富度。
难点在于怎么处理同种颜色的覆盖关系。
我们对于每个颜色建立一个 set
,保存被染上这种颜色的节点 dfs
序位置。
假设现在正在进行一个修改操作,将 x x x 染成 c c c 色:
-
找到
set
中 x x x 的前一个点,判断其是否是 x x x 的祖先,如果是的话直接跳过本次操作。 -
通过
set
找到 x x x 子树中已经被染上 c c c 色的节点。 -
将这些节点的子树权值总体减 1 1 1,同时把这些节点移出
set
。 -
将 x x x 的
dfs
序放入对应的set
,同时将 x x x 子树中的节点权值总体加 1 1 1。
解释一下第一步:
第一步在判断 x x x 的子树是否已经被整体染上了 c c c 色。
如果 x x x 的祖先节点已经被事先染上了 c c c 色,就不用再染了( x x x 的祖先节点的子树也已经被处理过了,所以不用管了)。
查询就很简单了,直接找线段树中 x x x 的子树部分的权值区间和就好。
AC CODE
#include <bits/stdc++.h>
using namespace std;
#define int long long
int read()
{
int x = 0;
char c = getchar();
while (c < '0' || c > '9')
c = getchar();
while (c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x;
}
#define _ 100007
int n, q;
int cnt_node, dfn[_], siz[_], b[_];
int tot, head[_], to[_ << 1], nxt[_ << 1];
void add(int u, int v)
{
to[++tot] = v;
nxt[tot] = head[u];
head[u] = tot;
}
void dfs(int u)
{
dfn[u] = ++cnt_node;
b[cnt_node] = u;
siz[u] = 1;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (dfn[v])
continue;
dfs(v);
siz[u] += siz[v];
}
}
int tr[_ << 2], tag[_ << 2];
void push_up(int o)
{
tr[o] = tr[o << 1] + tr[o << 1 | 1];
}
void push_down(int o, int l, int r)
{
if (tag[o])
{
int mid = (l + r) >> 1;
tr[o << 1] += (mid - l + 1) * tag[o];
tr[o << 1 | 1] += (r - mid) * tag[o];
tag[o << 1] += tag[o];
tag[o << 1 | 1] += tag[o];
tag[o] = 0;
}
}
void update(int o, int l, int r, int L, int R, int val)
{
if (L <= l && r <= R)
{
tr[o] += (r - l + 1) * val;
tag[o] += val;
return;
}
int mid = (l + r) >> 1;
push_down(o, l, r);
if (L <= mid)
update(o << 1, l, mid, L, R, val);
if (R > mid)
update(o << 1 | 1, mid + 1, r, L, R, val);
push_up(o);
}
int query(int o, int l, int r, int L, int R)
{
if (L <= l && r <= R)
return tr[o];
int mid = (l + r) >> 1;
push_down(o, l, r);
int res = 0;
if (L <= mid)
res += query(o << 1, l, mid, L, R);
if (R > mid)
res += query(o << 1 | 1, mid + 1, r, L, R);
return res;
}
std::set<int> st[_];
signed main()
{
n = read(), q = read();
for (int i = 1; i < n; i++)
{
int u = read(), v = read();
add(u, v);
add(v, u);
}
dfs(1);
while (q--)
{
int opt = read();
if (opt == 1)
{
int x = read(), c = read();
auto t = st[c].upper_bound(dfn[x]);
if (t != st[c].begin())
{
t--;
if (dfn[x] >= *t && dfn[x] <= *t + siz[b[*t]] - 1)
continue;
}
t = st[c].lower_bound(dfn[x]);
while (t != st[c].end() && *t <= dfn[x] + siz[x] - 1)
{
update(1, 1, n, *t, *t + siz[b[*t]] - 1, -1);
st[c].erase(t);
t = st[c].lower_bound(dfn[x]);
}
update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, 1);
st[c].insert(dfn[x]);
}
else
{
int x = read();
printf("%lld\n", query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
}
}
return 0;
}