链接:戳这里
Weak Pair
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)
Problem Description
You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. A non-negative value $a_i$ is assigned to the $i$-th node. An ordered pair of nodes $(u, v)$ is said to be weak if
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a1 to aN.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constrains:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
题意:
n个节点的树,节点的点权为ai,要求找出有多少个二元组(u,v)满足
1:u是v的祖先且u!=v
2:a[u]*a[v]<=K
思路:
跑DFS时,从根到当前节点v的递归路径上的节点恰好是v的全部祖先;但DFS也会访问兄弟子树,所以回溯离开一个节点时要把它从集合中删掉(想象一下DFS序)。
这样访问v时,集合里的节点都是v的祖先,只需要快速找出有多少个祖先满足条件
那么我们考虑a[u]*a[v]<=K,则要快速找出前面有多少个祖先a[u]满足a[u]<=K/a[v](整除;若a[v]=0,则所有祖先都满足,此时不能做除法),线段树维护就可以了
这里数据很大所以要离散化
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<vector>
#include <ctime>
#include<queue>
#include<set>
#include<map>
#include<list>
#include<stack>
#include<iomanip>
#include<cmath>
#include<bitset>
#define mst(ss,b) memset((ss),(b),sizeof(ss))
///#pragma comment(linker, "/STACK:102400000,102400000")
typedef long long ll;
typedef long double ld;
// Fully parenthesized so that expressions such as -INF or 2*INF expand as intended.
#define INF ((1ll<<60)-1)
#define Max 1e9
using namespace std;
int T;                     // number of test cases
int n,m;                   // n: node count; m: number of distinct compressed values in b[1..m]
ll a[200100],b[200100],K;  // a[i]: value of node i; b: sorted distinct values used for compression; K: the bound k
int deep[100100];          // deep[v]: how many times v appeared as a child (0 marks a root)
int sum[800100];           // segment-tree counters; node r has children 2r and 2r+1
// Initializes the segment-tree node `root`, which covers positions [l, r],
// and all of its descendants to a count of zero.
// The children of `root` are 2*root and 2*root+1.
void build(int root,int l,int r){
    const int lc=root<<1;
    const int rc=lc|1;
    if(l!=r){
        const int mid=l+(r-l)/2;
        build(lc,l,mid);
        build(rc,mid+1,r);
        sum[root]=sum[lc]+sum[rc];
    }else{
        sum[root]=0;
    }
}
// Returns the total count stored in positions [x, y].
// Here `root` covers [l, r], and the caller guarantees l <= x <= y <= r.
int query(int root,int l,int r,int x,int y){
    if(l>=x && r<=y) return sum[root];          // node fully inside the range
    const int mid=(l+r)>>1;
    if(mid>=y) return query(root<<1,l,mid,x,y);  // range lies in the left half
    if(mid<x) return query(root<<1|1,mid+1,r,x,y); // range lies in the right half
    const int leftPart=query(root<<1,l,mid,x,mid);
    const int rightPart=query(root<<1|1,mid+1,r,mid+1,y);
    return leftPart+rightPart;
}
// Adds v to position x (where l <= x <= r, and `root` covers [l, r]),
// then refreshes every ancestor on the way back up.
void update(int root,int l,int r,int x,ll v){
    int path[64];   // nodes visited on the way down; the depth is at most ~32 for int ranges
    int depth=0;
    while(l!=r){
        path[depth++]=root;
        const int mid=(l+r)/2;
        if(x>mid){
            root=root*2+1;
            l=mid+1;
        }else{
            root=root*2;
            r=mid;
        }
    }
    sum[root]+=v;
    while(depth>0){
        const int p=path[--depth];
        sum[p]=sum[p*2]+sum[p*2+1];
    }
}
// One arc of the adjacency list filled by Add().
// `v` is the node the arc leads to; `next` is the index of the next arc
// with the same source (-1 ends the list).
struct edge{
    int v;
    int next;
}e[200100];
int head[100100];  // head[u]: index in e[] of u's most recently added arc (-1 = none)
int tot=0;         // number of arcs stored in e[] so far

// Prepends the arc u -> v to u's adjacency list.
void Add(int u,int v){
    const int id=tot++;
    e[id].v=v;
    e[id].next=head[u];
    head[u]=id;
}
ll ans;  // running count of weak pairs for the current test case

// Returns r such that b[1..r] are exactly the compressed values x with
// x * a[u] <= K, i.e. x <= floor(K / a[u]).
// When a[u] == 0 every value qualifies (0 * x == 0 <= K), so all m slots count.
// This also avoids dividing by zero.
static int weak_rank(int u){
    if(a[u]==0) return m;
    return static_cast<int>(std::upper_bound(b+1,b+m+1,K/a[u])-(b+1));
}

// Called when the traversal reaches u.
// First adds to `ans` the ancestors currently on the root path that form a
// weak pair with u. Then records a[u] on the path and returns u's compressed
// position so it can be removed on exit.
static int enter_node(int u){
    const int lim=weak_rank(u);
    if(lim>=1) ans+=query(1,1,m,1,lim);
    const int pos=static_cast<int>(std::lower_bound(b+1,b+m+1,a[u])-b);
    update(1,1,m,pos,1);
    return pos;
}

// Counts the weak pairs inside the subtree rooted at u and adds them to `ans`.
// The segment tree holds the values of the nodes on the current root path,
// i.e. exactly the ancestors of the node being entered.
// The traversal uses an explicit stack instead of recursion, so a chain of
// N = 1e5 nodes cannot overflow the call stack. The visiting order matches a
// recursive pre-order walk over head[]/next.
void DFS(int u){
    std::vector<int> edge_it;   // per path node: next unvisited arc in e[] (-1 when exhausted)
    std::vector<int> path_pos;  // per path node: compressed position of its value
    path_pos.push_back(enter_node(u));
    edge_it.push_back(head[u]);
    while(!edge_it.empty()){
        const int cur=edge_it.back();
        if(cur!=-1){
            const int v=e[cur].v;
            edge_it.back()=e[cur].next;          // advance before descending
            path_pos.push_back(enter_node(v));
            edge_it.push_back(head[v]);
        }else{
            // All children done: this node stops being an ancestor.
            update(1,1,m,path_pos.back(),-1);
            path_pos.pop_back();
            edge_it.pop_back();
        }
    }
}
// Reads T test cases. For each one it:
//  - reads n, K and the node values;
//  - coordinate-compresses the values together with the thresholds K/a[i];
//  - reads the n-1 parent->child edges;
//  - runs DFS from every root and prints the number of weak pairs.
int main(){
    scanf("%d",&T);
    for(int cas=1;cas<=T;cas++){
        mst(deep,0);
        mst(head,-1);
        ans=tot=0;
        // %lld is the standard long long length modifier.
        // %I64d is a Microsoft CRT extension and is not portable.
        scanf("%d%lld",&n,&K);
        m=0;
        for(int i=1;i<=n;i++){
            scanf("%lld",&a[i]);
            b[++m]=a[i];
        }
        // Also compress each threshold K/a[i].
        // A zero value has no finite threshold (every ancestor qualifies),
        // and dividing by it is undefined, so skip it.
        for(int i=1;i<=n;i++)
            if(a[i]!=0) b[++m]=K/a[i];
        sort(b+1,b+m+1);
        m=unique(b+1,b+m+1)-(b+1);
        // build() resets every node that query/update touch for [1, m],
        // so no separate memset of sum is needed.
        build(1,1,m);
        for(int i=1;i<n;i++){
            int u,v;
            scanf("%d%d",&u,&v);
            Add(u,v);
            deep[v]++;   // nodes that never appear as a child are roots
        }
        for(int i=1;i<=n;i++)
            if(deep[i]==0) DFS(i);
        printf("%lld\n",ans);
    }
    return 0;
}