题意 :给你一棵树,定义:若一个节点为一个节点的父亲节点,并且这两个节点的乘积小于等于k,那么称之为一条路径,问你一共有多少条路径。
题解 : 首先对一第一个条件我们不难想到dfs 因为在做dfs 的时候访问到的节点的祖先节点肯定都被访问过了 (这是dfs的特性)这样的话我们在访问每个节点的时候就可以将这个节点的 值扔进一棵线段树或者树状数组更新该节点,使得该节点的值 +1 ,访问完这个节点的所有子树后再将这个点更新 -1 然后对于一个新访问的节点查询 小于等于 k / a[u] 的个数 其实就是一个前缀和,就是线段树所围护的东西。 还有这个题目的数据有1e9 必须进行离散化,主要是我们要考虑到dfs的特性,访问它的时候已经访问过他的祖先,再考虑线段树维护一下就可以。
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#define ll long long
using namespace std;
const int maxn = 1e5 + 10;
struct node {
int l,r,sum;
}tr[maxn << 2];
//int cnt[maxn] = {0};
int n;
ll k,ans;
bool vis[maxn] = {0};
ll a[maxn] = {0};
vector <int> G[maxn];
vector <ll> v;
ll getid (ll x) { return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;}
void build (int l,int r,int root) {
tr[root].l = l;
tr[root].r = r;
tr[root].sum = 0;
if (l == r) return ;
int mid = (l + r) >> 1;
build(l, mid, root << 1);
build(mid +1, r, root << 1| 1);
}
void update (int root,int pos,int add) {
if (tr[root].l == tr[root].r) {
tr[root].sum += add;
return ;
}
int mid = (tr[root].l + tr[root].r) >> 1;
if (mid >= pos) update(root << 1, pos, add);
else update(root << 1 | 1, pos, add);
tr[root].sum = tr[root << 1].sum + tr[root << 1 | 1].sum;
}
int query (int l,int r,int root) {
// cout << l << ' ' << r << endl;
if (tr[root].l >= l && tr[root].r <= r)
return tr[root].sum;
int mid = (tr[root].l + tr[root].r) >> 1;
int temp1 = 0,temp2 = 0;
if (mid >= l) temp1 = query(l, r, root << 1);
if (mid < r) temp2 = query(l, r, root << 1 | 1);
return temp1 + temp2;
}
void dfs (int u) {
ll cnt = G[u].size();
ll y = k / a[u];
ll id = upper_bound (v.begin(),v.end(),y) - v.begin();
// cout << u << ' ' << ans << endl;
ans += query(1, id, 1);
// cout << u << ' ' << ans << endl;
id = getid (a[u]);
update(1, id, 1);
for (int i = 0;i < cnt; ++ i) {
int x = G[u][i];
dfs(x);
}
update(1, id, -1);
}
int main () {
ios_base :: sync_with_stdio(false);
int t;
cin >> t;
while (t--) {
cin >> n >> k;
build (1,n,1);
ans = 0;
memset (vis,0,sizeof (vis));
v.clear();
memset (a,0,sizeof (a));
for (int i = 1;i <= n ; ++ i) {
cin >> a[i];
v.push_back(a[i]);
G[i].clear();
}
sort (v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
for (int i = 1;i < n; ++ i) {
int x,y;
cin >> x >> y;
G[x].push_back(y);
vis[y] = 1;
}
int root = 1;
for (int i = 1;i <= n; ++ i)
if (!vis[i]) {
root = i;
break;
}
dfs (root);
cout << ans << endl;
}
return 0;
}