Weak Pair
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others) Total Submission(s): 2421 Accepted Submission(s): 745
Problem Description
You are given a *rooted* tree of $N$ nodes, labeled from $1$ to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An *ordered* pair of nodes $(u, v)$ is said to be *weak* if
(1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
(1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
数据范围比较大,直接用树状数组肯定不行,需要离散化,之后用dfs扫一遍就好啦。第一次在正式比赛中使用树状数组过题,开森!以前用线段树写的时候总是会写残,突然发现树状数组实在是好用!!AC代码如下:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<string>
#include<map>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
// Global state shared by the Fenwick-tree helpers, dfs and main.
// Node values are discretized (replaced by their rank in sorted order) so
// that they can be used as indices into the Fenwick tree c[].
typedef long long ll; // 64-bit type used for k, the query sums and the pair count
// Per-node record used to compute each node's value rank.
typedef struct Node{
int v;        // the node's value a_i
int pos, num; // pos: 1-based rank of v in sorted order (Fenwick index); num: node label
}Node;
int v[100010];     // values a_1..a_n; main sorts this ascending for upper_bound lookups in dfs
Node node[100010]; // after main's final sort, node[i] holds the record for node label i
int c[100010], f[100010], vis[100010]; // c: Fenwick tree counting ranks of the nodes on the current DFS root path; f: parent of each node (itself if it has none); vis: dfs visited flags
vector<int>e[100010]; // e[u]: children of u, built from the input edges
ll ans = 0, k; // ans: weak-pair count for the current case; k: the product bound
int n;         // number of nodes in the current case (also the Fenwick tree size)
// Follows parent links stored in f[] starting from node a until it reaches
// a node that is its own parent, and returns that node (the root of the
// tree containing a when f[] encodes a tree).
int find(int a)
{
    int cur = a;
    while (f[cur] != cur)
        cur = f[cur];
    return cur;
}
// Strict-weak ordering by node value (ascending); used to rank nodes by a_i.
bool cmp1(Node a, Node b)
{
    const bool value_smaller = b.v > a.v;
    return value_smaller;
}
// Strict-weak ordering by node label (ascending); restores label order
// after the nodes have been ranked by value.
bool cmp2(Node a, Node b)
{
    const bool label_smaller = b.num > a.num;
    return label_smaller;
}
// Returns the value of the lowest set bit of x (0 when x == 0).
// This is the step size used by the Fenwick-tree update/query loops.
int lowbit(int x)
{
    const int lowest = x & -x;
    return lowest;
}
// Fenwick-tree point update: adds val at 1-based position x, touching every
// tree node that covers x up to index n. x must be >= 1.
void update(int x, int val)
{
    while (x <= n)
    {
        c[x] += val;
        x += x & -x; // advance to the next covering node
    }
}
// Fenwick-tree prefix query: returns c-sum over positions 1..x, i.e. how
// many counted entries have rank <= x. Returns 0 for x <= 0.
ll query(int x)
{
    ll total = 0;
    while (x > 0)
    {
        total += c[x];
        x -= lowbit(x); // drop the lowest set bit to move to the previous range
    }
    return total;
}
// Depth-first traversal below node x.
// Invariant on entry: the Fenwick tree c[] holds exactly the value ranks of
// x and all of x's ancestors (main inserts the root before the first call).
// For each child t, the ancestors of t are those nodes, so we count how many
// of them satisfy a_u * a_t <= k, add t to the tree, recurse, then remove t.
void dfs(int x)
{
    vis[x] = 1;
    for (int t : e[x])
    {
        int pos;
        if (node[t].v == 0)
        {
            // 0 * a_u == 0 <= k for every ancestor (k >= 0), and k / 0 would
            // be a division by zero: count all ranks.
            pos = n;
        }
        else
        {
            // a_u * a_t <= k  <=>  a_u <= floor(k / a_t) for a_t > 0;
            // pos = number of values <= that bound = highest usable rank.
            pos = upper_bound(v + 1, v + n + 1, k / node[t].v) - v - 1;
        }
        ans += query(pos);

        update(node[t].pos, 1);
        if (!vis[t])
            dfs(t);
        update(node[t].pos, -1);
    }
}
// Reads T test cases. For each case it:
//  - reads the tree,
//  - ranks the node values (discretization) so that they can index the
//    Fenwick tree,
//  - runs dfs from the root and prints the number of weak pairs.
int main()
{
    int T = 0;
    // Without this check T would be used uninitialized on empty/garbled input.
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--)
    {
        ans = 0;
        if (scanf("%d%lld", &n, &k) != 2)
            break;
        memset(c, 0, sizeof(c));
        memset(vis, 0, sizeof(vis));
        for (int i = 1; i <= n; i++)
        {
            e[i].clear();
            scanf("%d", &v[i]);
            node[i].v = v[i];
            node[i].num = i;
            f[i] = i; // every node is its own parent until an edge says otherwise
        }
        // Sorted copy of the values, searched with upper_bound in dfs.
        sort(v + 1, v + n + 1);
        // Assign each node its 1-based rank by value...
        sort(node + 1, node + n + 1, cmp1);
        for (int i = 1; i <= n; i++)
            node[i].pos = i;
        // ...then restore label order so node[i] again describes node i.
        sort(node + 1, node + n + 1, cmp2);
        for (int i = 1; i < n; i++)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            e[a].push_back(b); // a is the parent of b
            f[b] = a;
        }
        int rt = find(1);
        // The root is an ancestor of every other node: seed the tree with it.
        update(node[rt].pos, 1);
        dfs(rt);
        printf("%lld\n", ans);
    }
    return 0;
}