用C++交会栈溢出,而G++不会。
更新和查询我用的是线段树,1500+ms,用树状数组应该会快一些。
将树形结构转换成线性结构后,等价于求指定区间内恰好出现k次的数有多少个。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
using namespace std;
const int maxn = 100010;
int t, n, k, q;
int w[maxn], a[maxn];
int L[maxn], R[maxn]; // 询问区间的左、右端点
int ans[maxn], id;
vector<int> vt[maxn]; // 临接表
vector<int> vv[maxn];
bool vis[maxn];
map<int, int> mp;
struct Query {
int l, r, id;
}Q[maxn];
// for segment tree:
int add[maxn<<2];
bool cmp(Query q1, Query q2)
{
return q1.r < q2.r;
}
void dfs(int x)
{ // 将树形结构变成线性结构
vis[x] = true;
L[x] = id;
a[id] = w[x];
int size = vt[x].size();
for (int i = 0; i < size; ++i) {
if (!vis[vt[x][i]]) {
id++;
dfs(vt[x][i]);
}
}
R[x] = id;
}
void pushDown(int rt)
{
if (add[rt]) {
add[rt<<1] += add[rt];
add[rt<<1|1] += add[rt];
add[rt] = 0;
}
return ;
}
void build(int l, int r, int rt)
{
add[rt] = 0;
if (l == r) return ;
int m = (l + r) >> 1;
build(l, m, rt << 1);
build(m + 1, r, rt << 1 | 1);
}
void update(int l, int r, int rt, int L, int R, int c)
{
if (L <= l && R >= r) {
add[rt] += c;
return ;
}
pushDown(rt);
int m = (l + r) >> 1;
if (L <= m) {
update(l, m, rt << 1, L, R, c);
}
if (R > m) {
update(m + 1, r, rt << 1 | 1, L, R, c);
}
}
int query(int l, int r, int rt, int p)
{
if (l == r) {
return add[rt];
}
pushDown(rt);
int m = (l + r) >> 1;
if (p <= m) {
return query(l, m, rt << 1, p);
} else {
return query(m + 1, r, rt << 1 | 1, p);
}
}
int main()
{
scanf("%d", &t);
for (int cas = 1; cas <= t; ++cas) {
scanf("%d%d", &n, &k);
mp.clear();
id = 1;
for (int i = 1; i <= n; ++i) {
scanf("%d", &w[i]);
// 离散化
if (mp[w[i]] == 0) {
mp[w[i]] = id++;
}
w[i] = mp[w[i]];
}
int u, v;
for (int i = 0; i < maxn; ++i) {
vt[i].clear();
vv[i].clear();
}
for (int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
vt[u].push_back(v);
vt[v].push_back(u);
}
memset(vis, false, sizeof(vis));
id = 1;
dfs(1);
scanf("%d", &q);
for (int i = 0; i < q; ++i) {
scanf("%d", &u);
Q[i].id = i;
Q[i].l = L[u];
Q[i].r = R[u];
}
sort(Q, Q + q, cmp);
build(1, n, 1);
int idx = 0;
for (int i = 1; i <= n; ++i) {
// 线段树第j个数表示[j, i]间出现k次的数的个数
int num = a[i];
vv[num].push_back(i);
int size = vv[num].size();
if (size >= k) {
if (size > k) {
// 1 ~ vv[num][size-k-1]都减1
update(1, n, 1, 1, vv[num][size-k-1], -1);
// vv[num][size-k-1]+1 ~ vv[num][size-k]都加1
update(1, n, 1, vv[num][size-k-1] + 1, vv[num][size-k], 1);
} else {
// 加1
update(1, n, 1, 1, vv[num][size-k], 1);
}
}
while (Q[idx].r == i) {
ans[Q[idx].id] = query(1, n, 1, Q[idx].l);
idx++;
}
}
if (cas != 1) {
printf("\n");
}
printf("Case #%d:\n", cas);
for (int i = 0; i < q; ++i) {
printf("%d\n", ans[i]);
}
}
return 0;
}