题目大意
给你一个树,问你这棵树有多少对Weak Pair(一个对(u,v)成立当且仅当u是v的祖先并且
av∗au<=k
题目分析
对于一个棵树来说最容易想到的就是dfs遍历这棵树了吧!!假如说现在遍历到v节点,那么很明显之前遍历过的一定是u的祖先,那么我们就可以进行查找Weak Pair了,对于当前节点u对应的value值我们需要找到所有父亲节点中值小于k/value的节点个数,那么很明显对于父亲节点的值我们需要进行存储,但是父亲节点的值很大,因此我们需要离散化,为了快速查找父亲节点中值小于k/value的节点个数我们可以用线段树或者树状数组进行处理,进行线段树的单点更新以及区间查询即可。
我自己写这道题的时候遇到的一些问题在这里说一下:1.因为要进行离散化,所以需要加value和k/value存储下来,注意k/value可能超int,因此需要用long long保存。2.因为线段树所对应的数组大小一定要注意开大一点,因为value和k/value都存储下来之后相当于原来大小的2倍了,不开大可能会超时(我就超时了一次)。3.根节点不一定是1,因此在输入边的时候我们可以用一个数组来存一下每个点的入度,入度为0的点才是根节点。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 2e5+5;
typedef long long LL;
#define mid (L+R)/2
#define lson o<<1, L, mid
#define rson o<<1|1, mid+1, R
LL a[maxn], b[maxn], head[maxn], degree[maxn], sum[maxn<<2];
LL T, N, k, ans, tot;
struct Edge{ //邻接表
int to,next;
}e[maxn];
void update(int o,int L,int R,int v,int cnt){ //更新
if(L == R){
sum[o] += cnt;
return ;
}
if(v <= mid) update(lson, v, cnt);
else update(rson, v, cnt);
sum[o] = sum[o<<1] + sum[o<<1|1];
}
LL query(int o,int L,int R,int l,int r){ //查询
if(l <= L && R <= r) return sum[o];
LL ret = 0;
if(l <= mid) ret += query(lson, l, r);
if(r > mid) ret += query(rson, l, r);
return ret;
}
void addedge(int from, int to){
e[tot].to = to;
e[tot].next = head[from];
head[from] = tot++;
}
int BS(LL b[], int n, LL val){ //二分查找
int left = 1, right = n;
while(left <= right){
int m = (left + right)/2;
if(b[m] == val) return m;
else if(b[m] < val) left = m+1;
else right = m-1;
}
return -1;
}
void dfs(int u){
ans += query(1, 1, 2*N, 1, BS(b, 2*N, k/a[u]));
int x = BS(b, 2*N, a[u]);
update(1, 1, 2*N, x, 1);
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
dfs(v);
}
update(1, 1, 2*N, x, -1);
}
void init(){
tot = ans = 0;
memset(degree, 0, sizeof(degree));
memset(sum, 0, sizeof(sum));
memset(head, -1, sizeof(head));
}
int main(){
scanf("%I64d", &T);
while(T--){
init();
scanf("%I64d%I64d", &N, &k);
for(int i = 1; i <= N; i++){
scanf("%I64d", &a[i]);
b[i] = a[i];
b[i+N] = k/a[i];
}
sort(b+1, b+1+2*N);
int from, to;
for(int i = 1; i < N; i++){
scanf("%d%d", &from, &to);
addedge(from, to);
degree[to]++;
}
for(int i = 1; i <= N; i++)
if(!degree[i]){ dfs(i); break; }
printf("%I64d\n", ans);
}
return 0;
}