Weak Pair
Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 1556    Accepted Submission(s): 501
Problem Description
You are given a rooted tree of $N$ nodes, labeled from $1$ to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be weak if
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
Source
Recommend
题意:给你一棵有根树,树上每个点都有权值,再给你它的父子关系,问有多少个有序点对 (u,v) 满足 u 是 v 的祖先(不含自身)且两点权值相乘 <= k。
思路:看别人的做法都是用splay做的,我还没学会,就用了其他方法:先把点权值和 k/点权值离散化,然后 dfs 整棵树;每到一个节点 v,就用线段树查询当前路径上的祖先节点中有多少个权值 <= k/a_v(若 a_v = 0,则所有祖先都满足,此时不能做除法),再在线段树中对 a_v 单点加一,回溯时再单点删除即可。答案最大约为 N(N-1)/2,会超过 int 范围,需要用 long long 保存。下面给代码:
#include<iostream>
#include<cmath>
#include<queue>
#include<cstdio>
#include<queue>
#include<algorithm>
#include<cstring>
#include<string>
#include <functional>
// Capacity for per-node arrays: N <= 1e5 plus slack for 1-based indexing.
#define maxn 100005
#define inf 0x3f3f3f3f
// Argument lists for query()'s recursive calls on the left / right child:
// (segment bounds, unchanged query range [l2, r2], child index of `now`).
#define lson l1,mid,l2,r2,now<<1
#define rson mid+1,r1,l2,r2,now<<1|1
// Left / right child index of segment-tree node `now`. The bodies are
// parenthesized so the macros stay correct inside larger expressions
// (unparenthesized, `ll + 1` would expand to `now<<1 + 1`, i.e. `now<<2`).
#define ll (now<<1)
#define rr (now<<1|1)
typedef long long LL;
using namespace std;
int vis[maxn], value[maxn], head[maxn], sumnum[maxn << 3];
LL x[maxn << 1];
int length1, ans, n, length2;
LL k;
struct node{
int v, next;
}p[maxn];
void build(int l, int r, int now){
if (l == r){
sumnum[now] = 0;
return;
}
int mid = (l + r) >> 1;
build(l, mid, ll);
build(mid + 1, r, rr);
sumnum[now] = 0;
}
void add(int l1, int r1, int num, int now){
if (l1 == r1){
sumnum[now]++;
return;
}
int mid = (l1 + r1) >> 1;
if (num <= mid)
add(l1, mid, num, ll);
else
add(mid + 1, r1, num, rr);
sumnum[now] = sumnum[ll] + sumnum[rr];
}
void del(int l1, int r1, int num, int now){
if (l1 == r1){
sumnum[now]--;
return;
}
int mid = (l1 + r1) >> 1;
if (num <= mid)
del(l1, mid, num, ll);
else
del(mid + 1, r1, num, rr);
sumnum[now] = sumnum[ll] + sumnum[rr];
}
int query(int l1, int r1, int l2, int r2, int now){
if (l2 <= l1&&r2 >= r1){
return sumnum[now];
}
int mid = (l1 + r1) >> 1;
int a = 0, b = 0;
if (l2 <= mid)
a = query(lson);
if (r2>mid)
b = query(rson);
return a + b;
}
void dfs(int son){
int num = lower_bound(x + 1, x + length1 + 1, k / value[son]) - x;
ans += query(1, length1, 1, num, 1);
int pos = lower_bound(x + 1, x + length1 + 1, value[son]) - x;
add(1, length1, pos, 1);
for (int i = head[son]; ~i; i = p[i].next){
dfs(p[i].v);
}
del(1, length1, pos, 1);
}
int main(){
int t;
scanf("%d", &t);
while (t--){
scanf("%d%lld", &n, &k);
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
ans = 0;
for (int i = 1; i <= n; i++){
scanf("%d", &value[i]);
x[i] = value[i];
}
for (int i = n + 1; i <= (n << 1); i++){
x[i] = k / value[i - n];
}
sort(x + 1, x + (n << 1) + 1);
length1 = unique(x + 1, x + (n << 1) + 1) - x;
build(1, length1, 1);
length2 = 0;
for (int i = 0; i<n - 1; i++){
int u, v;
scanf("%d%d", &u, &v);
vis[v] = 1;
p[length2].v = v;
p[length2].next = head[u];
head[u]=length2++;
}
for (int i = 1; i <= n; i++){
if (!vis[i])
dfs(i);
}
printf("%d\n", ans);
}
}