解题思路:
【题意】
给你一棵有根树,一个定值k,以及树上每个结点的值a[i]
对于有序对(u,v),如果(1)u是v的祖先,且(2)a[u]*a[v]<=k,则称该有序对(u,v)是弱的
问树中有多少对有序对(u,v)是弱的
【类型】
离散化+dfs+树状数组
【分析】
对于要求(1),u是v的祖先,我们可以采取dfs
遍历到v时,它上方的所有结点必定都是满足第一条件的u
熟悉dfs过程的应该能理解这一点,不理解的可以借助下述图片稍微理解一下
从上图中,我们可以大致看出dfs过程是从树根开始向树叶访问的
对于某结点v,它的祖先u肯定是先于它被访问的,不然也不可能到达结点v
正如上图,结点10的祖先是结点1,2,4,8,不管哪个祖先,一旦有一个没被访问,也不可能达到结点10
此外,在退出某个子树的时候,该子树下结点的影响会被消除,这样就能保证所有有影响的都是祖先
要求(2),a[u]*a[v]<=k,那么到v的时候,所有小于等于k/a[v]的u都满足,可以想到树状数组
结点的值a[i]最大10亿,要用树状数组的话肯定要离散化
离散化的时候要把k/a[v]加进去一起离散,保证大小关系不变
另外,当a[i]=0时,会出现除以0错误,所以我们要特判该情况
显然a[i]=0的话,任何满足要求(1)的结点都可以构成弱的有序对
所以将该条件下的k/a[i]的结果直接设置为inf
#include<bits/stdc++.h>
using namespace std;
#define INF (1ll<<60)-1
#define LL long long
#define N 100005
LL k, ans, a[N], b[N*2], sum[N<<4];
int deep[N], head[N], tol, m;
struct Edge
{
int v, nxt;
}edge[N];
void init()
{
ans = tol = 0;
memset(sum, 0, sizeof(sum));
memset(deep, 0, sizeof(deep));
memset(head, -1, sizeof(head));
}
void build(int rt, int left, int right)
{
if(left == right)
{
sum[rt] = 0;
return ;
}
int mid = (left+right)>>1;
build(rt<<1, left, mid);
build(rt<<1|1, mid+1, right);
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void addedge(int u, int v)
{
edge[tol].v = v;
edge[tol].nxt = head[u];
head[u] = tol++;
}
int query(int rt, int left, int right, int l, int r)
{
if(l<=left&&r>=right) return sum[rt];
int mid = (left + right) >> 1;
if(r <= mid) return query(rt<<1, left, mid, l, r);
else if(l > mid) return query(rt<<1|1, mid+1, right, l, r);
else return query(rt<<1, left, mid, l, r) + query(rt<<1|1, mid+1, right, l, r);
}
void update(int rt, int left, int right, int pos, int val)
{
if(left == right)
{
sum[rt] += val;
return ;
}
int mid = (left+right)>>1;
if(pos <= mid) update(rt<<1, left, mid, pos, val);
else update(rt<<1|1, mid+1, right, pos, val);
sum[rt] = sum[rt<<1]+sum[rt<<1|1];
}
void dfs(int u)
{
LL lim;
if(a[u] == 0) lim = INF;
else lim = k/a[u];
int l = lower_bound(b+1, b+m+1, lim) - b;
int pos = lower_bound(b+1, b+m+1, a[u]) - b;
ans += query(1, 1, m, 1, l);
update(1, 1, m, pos, 1);
for(int i = head[u]; i != -1; i = edge[i].nxt) dfs(edge[i].v);
update(1, 1, m, pos, -1);
}
void solve()
{
int n;
init();
scanf("%d%I64d", &n, &k);
for(int i = 1; i <= n;i++)
{
scanf("%I64d", &a[i]);
b[i] = a[i];
if(a[i]!=0)
b[i+n] = k/a[i];
else
b[i+n] = INF;
}
sort(b+1, b+n*2+1);
m = unique(b+1, b+n*2+1) - (b+1);
build(1, 1, m);
for(int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
addedge(u, v);
deep[v]++;
}
for(int i = 1; i <= n; i++)
{
if(deep[i] == 0)
{
dfs(i);
break;
}
}
printf("%I64d\n", ans);
}
int main()
{
int t;
scanf("%d", &t);
while(t--)
{
solve();
}
return 0;
}