题目大意:给定一棵以 1 为根的树,每个结点有一种颜色;每次询问 (v, k),求 v 的子树中出现次数至少为 k 的颜色有多少种。(来源:https://www.luogu.org/problem/CF375D)
询问都针对子树。一种做法是求出每个结点的 dfs 序:一棵子树恰好对应 dfs 序上的一段连续区间,于是子树询问变成区间询问,可以直接在序列上跑莫队。
另一种做法是 dsu on tree:维护当前子树内每种颜色的出现次数,再用线段树或树状数组以"出现次数"为下标,维护出现次数恰为该值的颜色种数;询问时查询下标 ≥ k 的区间和即可。
复杂度分别为 $O(n\sqrt{n})$ 和 $O(n\log^2 n)$(好像还有 $O(n\log n)$ 的做法)。
#include<bits/stdc++.h>
using namespace std;
// Argument-list shorthands for recursing into the left / right child of
// segment-tree node `rt` covering [l, r]; they require a local `mid`.
#define lson rt << 1,l,mid
#define rson rt << 1 | 1,mid + 1,r
// Array capacity. The upd/qry calls in modify()/dfs() use the value range
// [0, maxn - 10] for the segment tree.
const int maxn = 1e5 + 10;
// sum[]   : segment tree indexed by occurrence count; leaf c (c >= 1) holds how
//           many colors currently occur exactly c times. Leaf 0 is decremented
//           whenever a color goes 0 -> 1, so it can be negative and is not a
//           meaningful count.
// color[] : color of each vertex (1-based).
// cnt[]   : occurrences of each color in the currently counted vertex set.
// vis[]   : set on a heavy child while its subtree is already counted, so
//           modify() skips it.
int sum[maxn << 2],color[maxn],cnt[maxn],vis[maxn];
// son[u]  : heavy child of u (0 if u has no children); f[u]: subtree size of u;
// ans[i]  : answer to the i-th query.
int son[maxn],f[maxn],ans[maxn];
// g[u]    : adjacency list of the (undirected) tree.
vector<int> g[maxn];
// One offline query attached to a vertex: threshold `k` and the query's
// position `id` in the input (used to store the answer).
struct ss{
    int k, id;
    ss(int ki, int i) : k(ki), id(i) {}
};
// q[u]: queries asked about the subtree of vertex u (filled in main).
vector<ss> q[maxn];
// n: number of vertices, m: number of queries.
int n,m;
// Zeroes every segment-tree node in the subtree of node `rt`, which covers
// the value range [l, r].
void build(int rt,int l,int r) {
    sum[rt] = 0;
    if(l != r) {
        const int half = (l + r) >> 1;
        build(rt << 1, l, half);
        build(rt << 1 | 1, half + 1, r);
    }
}
// Point update: adds `v` to leaf `p` of the segment tree rooted at `rt`
// (covering [l, r]) and refreshes the sums on the path back up.
void upd(int p,int v,int rt,int l,int r) {
    if(l != r) {
        const int half = (l + r) >> 1;
        if(p > half) upd(p, v, rt << 1 | 1, half + 1, r);
        else upd(p, v, rt << 1, l, half);
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
        return;
    }
    sum[rt] += v;
}
// Range-sum query over [L, R] in the segment tree node `rt` covering [l, r].
// An empty query range (L > R) yields 0.
int qry(int L,int R,int rt,int l,int r) {
    if(L > R) return 0;
    if(L <= l && r <= R) return sum[rt];
    const int half = (l + r) >> 1;
    const int fromLeft = (L <= half) ? qry(L, R, rt << 1, l, half) : 0;
    const int fromRight = (R > half) ? qry(L, R, rt << 1 | 1, half + 1, r) : 0;
    return fromLeft + fromRight;
}
// First DFS: computes subtree sizes f[] and picks the heavy child son[u]
// (the child with the largest subtree; 0 when u has no children).
// NOTE(review): recursive — depth equals tree height, up to n for a path.
void prework(int u,int fa) {
    f[u] = 1;
    son[u] = 0;
    for(const int v : g[u]) {
        if(v != fa) {
            prework(v, u);
            f[u] += f[v];
            const bool heavier = (son[u] == 0) || (f[v] > f[son[u]]);
            if(heavier) son[u] = v;
        }
    }
}
// Adds `v` (+1 or -1) to the occurrence count of every vertex color in the
// subtree of u, skipping subtrees whose root has vis[] set. For each change,
// the color is moved from its old count's leaf to its new count's leaf.
void modify(int u,int fa,int v) {
    const int top = maxn - 10;
    int &c = cnt[color[u]];
    upd(c, -1, 1, 0, top);
    c += v;
    upd(c, 1, 1, 0, top);
    for(const int w : g[u])
        if(w != fa && !vis[w]) modify(w, u, v);
}
void dfs(int u,int fa,int keep) {
for(auto it : g[u]) {
if(it == fa || it == son[u]) continue;
dfs(it,u,0);
}
if(son[u]) dfs(son[u],u,1),vis[son[u]] = 1;
modify(u,fa,1);
for(auto it : q[u]) {
int p = it.id;
ans[p] = qry(it.k,maxn - 10,1,0,maxn - 10);
}
if(son[u]) vis[son[u]] = 0;
if(!keep) modify(u,fa,-1);
}
// Reads n, m, the n vertex colors, n - 1 undirected edges and m queries (u, k).
// Answers every query offline with DSU on tree rooted at vertex 1, then prints
// the answers in input order.
int main() {
    // Value range of the count-indexed segment tree. It must match the range
    // passed to upd/qry in modify() and dfs(); previously build() used
    // [1, maxn - 10] while every update/query used [0, maxn - 10].
    const int top = maxn - 10;
    if(scanf("%d%d",&n,&m) != 2) return 0;
    for(int i = 1; i <= n; i++)
        scanf("%d",&color[i]);
    for(int i = 1; i < n; i++) {
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for(int i = 1; i <= m; i++) {
        int u,k;
        scanf("%d%d",&u,&k);
        q[u].push_back(ss(k,i));
    }
    build(1,0,top);
    prework(1,0);
    dfs(1,0,1);
    for(int i = 1; i <= m; i++)
        printf("%d\n",ans[i]);
    return 0;
}