题目:
一棵有n个节点的树,每个点有一个权值。给出k,要求找出树上有多少个数对 (u,v) 满足如下条件:
- u是v的祖先。(u,v不同)
- au∗av≤k
分析:
利用dfs的过程,在dfs遍历一棵树的时候,若此时在v点,那么可以保证的是v点所有的祖先都已经被遍历了。所以,我们可以把v之前的点加入线段树,到v的时候直接查询出在祖先中有多少个 ≤av 。但是,有一些点是v的兄弟,有可能也已经入树了。所以我们想把v的兄弟“踢出去”。可以这样实现:当一个节点的子树已经被遍历了,那这个点就应该被踢出来,因为它不再会是某个未访问点的祖先。
这题的数据比较大,需要离散化。离散化后大数对应大下标,小数对应小下标,所以想查询有多少个 ≤k 的数时只需要看在区间 [1,rk] 有多少个数即可,简单的区间求和。
代码:
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <cstdio>
using namespace std;
#define ms(a,b) memset(a,b,sizeof(a))
#define lson rt*2,l,(l+r)/2
#define rson rt*2+1,(l+r)/2+1,r
typedef unsigned long long ull;
typedef long long ll;
const int MAXN=2e5+5;
const double EPS=1e-8;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int MOD = 1e9+7;
struct{
int v,next;
}e[MAXN];
int head[MAXN],tot,n,m,tree[MAXN << 2],deep[MAXN];
ll ans,k,a[MAXN],b[MAXN];
void init(){
tot = 0;
ans = 0;
ms(head,-1);
ms(tree,0);
ms(deep,0);
}
void addedge(int u,int v){
e[tot].v = v;
e[tot].next = head[u];
head[u] = tot++;
}
void pushup(int rt){
tree[rt] = tree[rt<<1] + tree[rt<<1|1];
}
ll query(int L,int R,int rt,int l,int r){
if(L<=l && R>=r){
return tree[rt];
}
if(R<=(l+r)/2) return query(L,R,lson);
else if(L > (l+r)/2) return query(L,R,rson);
else return query(L,R,lson) + query(L,R,rson);
}
void update(int rt,int l,int r,int p,int x){
if(l==r){
tree[rt] += x;
return;
}
if(p <= (l+r)/2) update(lson,p,x);
else update(rson,p,x);
pushup(rt);
}
void dfs(int v){
int r = lower_bound(b+1,b+m+1,k/a[v]) - b;
int pos = lower_bound(b+1,b+m+1,a[v]) - b;
ans += query(1,r,1,1,m);
update(1,1,m,pos,1);
for(int i=head[v];i!=-1;i=e[i].next){
dfs(e[i].v);
}
update(1,1,m,pos,-1);
}
int main(){
int T;
scanf("%d",&T);
while(T--){
init();
scanf("%d%lld",&n,&k);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
b[i] = a[i];
if(a[i]!=0) b[n+i] = k/a[i];
else b[n+i] = INF;
}
sort(b+1,b+2*n+1);
m = unique(b+1,b+2*n+1)-(b+1);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
deep[v]++;
}
for(int i=1;i<=n;i++){
if(deep[i] == 0){
dfs(i);
break;
}
}
printf("%lld\n",ans);
}
return 0;
}