// Answer all queries offline, then run centroid decomposition (tree divide-and-conquer) over the tree.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// Lowest set bit of x (the Fenwick-tree step). The argument is parenthesised
// so that an expression argument such as lowbit(a + b) expands correctly.
#define lowbit(x) ((x) & (-(x)))
#define pii pair<int, int>
#define mp(x, y) make_pair(x, y)
const int maxn = 100005;   // capacity for per-node arrays (nodes are numbered from 1)
const int maxm = 200005;   // edge-pool capacity: every tree edge is stored in both directions
const int INF = 0x3f3f3f3f;
// Singly linked adjacency list. H[u] is the first edge leaving u, E is the
// edge pool and `edges` points at the next unused slot of E.
struct Edge
{
    int v;          // endpoint of the edge
    Edge *next;     // next edge leaving the same node
} *H[maxn], *edges, E[maxm];
// q[x]: queries asked at node x, stored as (distance limit d, query id).
// dis / dis1 / dis2: (depth, node) pairs collected by dfs() from the current
// centroid, split by the category of the value sequence along the path.
vector<pii> q[maxn], dis1, dis2, dis;
bool done[maxn];    // true once a node has been used as a centroid
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is
// ambiguous with std::size; qualify uses of this array as `::size`.
int size[maxn];     // subtree sizes computed by getroot()
int res[maxn];      // accumulated answer for each query id
int mx[maxn];       // largest piece left if the node were removed (getroot)
int a[maxn];        // value stored on each node
int tree[maxn];     // Fenwick tree over depth for dis  (category 0 nodes)
int tree1[maxn];    // Fenwick tree over depth for dis1 (category 1 nodes)
int tree2[maxn];    // Fenwick tree over depth for dis2 (category 2 nodes)
int n, m, root, nsize;
void addedges(int u, int v)
{
edges->v = v;
edges->next = H[u];
H[u] = edges++;
}
// Reset the per-test-case state: no node is a used centroid, every answer is
// zero, every adjacency list is empty, and the edge pool starts over.
void init()
{
    memset(done, 0, sizeof(done));
    memset(res, 0, sizeof(res));
    memset(H, 0, sizeof(H));
    edges = E;
}
// Centroid search over the component of not-yet-done nodes that contains u.
// Fills ::size[x] with the subtree size of x (rooted at the starting node) and
// mx[x] with the largest piece that would remain if x were removed, using the
// global nsize as the component size. Updates the global `root` whenever a
// node with a smaller mx is found. Callers set root = 0 and mx[0] = nsize
// before the call, so node 0 acts as the initial sentinel.
void getroot(int u, int fa)
{
    // `::size` instead of `size`: under C++17 with `using namespace std;`,
    // the unqualified name is ambiguous with std::size and fails to compile.
    mx[u] = 0, ::size[u] = 1;
    for (Edge *e = H[u]; e; e = e->next) if (!done[e->v] && e->v != fa) {
        int v = e->v;
        getroot(v, u);
        ::size[u] += ::size[v];
        mx[u] = max(mx[u], ::size[v]);
    }
    // The part of the component "above" u is also a piece if u is removed.
    mx[u] = max(mx[u], nsize - ::size[u]);
    if (mx[u] < mx[root]) root = u;
}
// Fenwick point update: add v at logical position x (0-based). Positions are
// shifted by one internally because Fenwick indices start at 1; internal
// indices 1..n+1 are maintained.
void add(int x, int v, int tree[])
{
    for (int idx = x + 1; idx <= n + 1; idx += lowbit(idx))
        tree[idx] += v;
}
// Fenwick prefix query: total of all values stored at logical positions
// 0..x. A negative x yields 0.
int sum(int x, int tree[])
{
    int total = 0;
    int idx = x + 1;
    while (idx > 0) {
        total += tree[idx];
        idx -= lowbit(idx);
    }
    return total;
}
// Walk away from the current centroid along paths whose values stay monotone,
// recording (dep, u) for every reached node in the bucket of its category:
//   flag 0 -> dis : all values on the path so far are equal
//   flag 1 -> dis1: values never increase and decreased at least once
//   flag 2 -> dis2: values never decrease and increased at least once
// Done nodes are never entered; branches that would break monotonicity
// are not explored.
void dfs(int u, int fa, int dep, int flag)
{
    switch (flag) {
    case 0: dis.push_back(mp(dep, u)); break;
    case 1: dis1.push_back(mp(dep, u)); break;
    case 2: dis2.push_back(mp(dep, u)); break;
    }
    for (Edge *e = H[u]; e != NULL; e = e->next) {
        int v = e->v;
        if (v == fa || done[v]) continue;
        int next = -1;  // -1: stepping onto v would break monotonicity
        if (flag == 0) {
            if (a[u] == a[v]) next = 0;
            else if (a[u] > a[v]) next = 1;
            else next = 2;
        } else if (flag == 1) {
            if (a[u] >= a[v]) next = 1;
        } else if (a[u] <= a[v]) {
            next = 2;
        }
        if (next != -1) dfs(v, u, dep + 1, next);
    }
}
void solve(int u)
{
done[u] = true;
dis.clear();
dis1.clear();
dis2.clear();
dfs(u, u, 0, 0);
for(int i = 0; i < dis.size(); i++) add(dis[i].first, 1, tree);
for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, 1, tree1);
for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, 1, tree2);
for(int i = 0; i < q[u].size(); i++) {
int t = 0, d = q[u][i].first, id = q[u][i].second;
t = sum(d, tree) + sum(d, tree1) + sum(d, tree2);
res[id] += t;
}
for(Edge *e = H[u]; e; e = e->next) if(!done[e->v]) {
int v = e->v;
dis.clear();
dis1.clear();
dis2.clear();
if(a[u] == a[v]) dfs(v, v, 1, 0);
if(a[u] > a[v]) dfs(v, v, 1, 1);
if(a[u] < a[v]) dfs(v, v, 1, 2);
for(int i = 0; i < dis.size(); i++) add(dis[i].first, -1, tree);
for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, -1, tree1);
for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, -1, tree2);
for(int i = 0; i < dis.size(); i++) {
int dist = dis[i].first, x = dis[i].second;
for(int j = 0; j < q[x].size(); j++) {
int t = 0, d = q[x][j].first, id = q[x][j].second;
if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree1) + sum(d - dist, tree2);
res[id] += t;
}
}
for(int i = 0; i < dis1.size(); i++) {
int dist = dis1[i].first, x = dis1[i].second;
for(int j = 0; j < q[x].size(); j++) {
int t = 0, d = q[x][j].first, id = q[x][j].second;
if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree2);
res[id] += t;
}
}
for(int i = 0; i < dis2.size(); i++) {
int dist = dis2[i].first, x = dis2[i].second;
for(int j = 0; j < q[x].size(); j++) {
int t = 0, d = q[x][j].first, id = q[x][j].second;
if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree1);
res[id] += t;
}
}
for(int i = 0; i < dis.size(); i++) add(dis[i].first, 1, tree);
for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, 1, tree1);
for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, 1, tree2);
}
dis.clear();
dis1.clear();
dis2.clear();
dfs(u, u, 0, 0);
for(int i = 0; i < dis.size(); i++) add(dis[i].first, -1, tree);
for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, -1, tree1);
for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, -1, tree2);
for(Edge *e = H[u]; e; e = e->next) if(!done[e->v]) {
int v = e->v;
mx[0] = nsize = size[v];
getroot(v, root = 0);
solve(root);
}
}
// Handle one test case. Reads n, m, the n node values, the n-1 undirected
// edges and the m queries (x, d). Answers every query offline through
// centroid decomposition, then prints one answer per line in input order.
void work()
{
    scanf("%d%d", &n, &m);
    for (int node = 1; node <= n; ++node)
        scanf("%d", a + node);
    for (int k = 1; k < n; ++k) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedges(x, y);
        addedges(y, x);
    }
    for (int node = 1; node <= n; ++node)
        q[node].clear();
    for (int id = 1; id <= m; ++id) {
        int x, d;
        scanf("%d%d", &x, &d);
        q[x].push_back(mp(d, id));
    }
    // Locate the centroid of the whole tree and start the decomposition.
    nsize = n;
    mx[0] = n;
    root = 0;
    getroot(1, 0);
    solve(root);
    for (int id = 1; id <= m; ++id)
        printf("%d\n", res[id]);
}
// Read the number of test cases and process each one on freshly reset state.
int main()
{
    int cases;
    scanf("%d", &cases);
    while (cases-- != 0) {
        init();
        work();
    }
    return 0;
}