题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877
题目大意:给你一棵树,有n个节点,每个点都有对应的ai,不同结点之间有关系,问其上的任意两个节点i,j,这两个节点满足那个关系,且ai*aj<=k。这样的节点对有多少个。
解题思路:
因为ai的值较大,所以需要离散化(很重要,而且我写了好长时间,题意都理解错了,真是够了)。
我们可以将ai*aj<=k,转化为ai<=k/aj;但要注意对于aj为零时的特殊处理。
我们可以从入度为零的父亲节点开始向下遍历,所遍历到的节点,与已经遍历的节点一定满足那个关系,所以可以直接统计当前节点上满足的有多少,最后要注意将遍历的节点删除,因为此节点与另一子树上的节点并不满足那个关系。
具体细节看AC代码:
#include<iostream>//离散化很重要
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
ll t,n;
ll k,sum;
const ll INF=1e5+10;
ll a[INF*2];
ll vis[INF],c[INF*2];
vector<ll>v[INF];
void prepare(ll *x)//离散化
{
ll data[2*INF];
for(ll i=1;i<=2*n;i++)
{
data[i]=x[i];
}
sort(data+1,data+1+2*n);
ll m=unique(data+1,data+1+2*n)-data-1;
for(ll i=1;i<=2*n;i++)
{
x[i]=lower_bound(data+1,data+m+1,x[i])-data;
}
}
ll lowbit(ll x)
{
return x&(-x);
}
void add(ll x,ll y)
{
while(x<=2*n)
{
c[x]=c[x]+y;
x=x+lowbit(x);
}
}
ll getsum(ll x)
{
ll ans=0;
while(x>0)
{
ans+=c[x];
x=x-lowbit(x);
}
return ans;
}
void dfs(ll x,ll m)//注意用什么更新数组c,用什么统计数目
{
ll val=getsum(a[x+n]);
sum=sum+val;
ll cnt=v[x].size() ;
ll i,j;
add(a[x],1);
for(i=0;i<cnt;i++)
{
j=v[x][i];
dfs(j,m);
}
add(a[x],-1);//删除该节点
}
int main()
{
cin>>t;
while(t--)
{
scanf("%lld%lld",&n,&k);
for(ll i=0;i<=n;i++)
v[i].clear() ;
ll i,j,val,l;
memset(vis,0,sizeof(vis));
memset(c,0,sizeof(c));
for(i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
if(j==0)a[i+n]=(1e18)+5;
else a[i+n] =k/a[i];
}
for(i=1;i<n;i++)
{
scanf("%lld%lld",&j,&val);
v[j].push_back(val);
vis[val]++;
}
//离散化
prepare(a);
for(i=1;i<=n;i++)
{
if(vis[i]==0)break;
}
sum=0;
dfs(i,0);
printf("%lld\n",sum);
}
return 0;
}
#include<iostream>//另一种思路,区别在dfs上
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
ll t,n;
ll k,sum;
const ll INF=1e5+10;
ll a[INF*2];
ll vis[INF],c[INF*2];
vector<ll>v[INF];
void prepare(ll *x)
{
ll data[2*INF];
for(ll i=1;i<=2*n;i++)
{
data[i]=x[i];
}
sort(data+1,data+1+2*n);
ll m=unique(data+1,data+1+2*n)-data-1;
for(ll i=1;i<=2*n;i++)
{
x[i]=lower_bound(data+1,data+m+1,x[i])-data;
}
}
ll lowbit(ll x)
{
return x&(-x);
}
void add(ll x,ll y)
{
while(x<=2*n)
{
c[x]=c[x]+y;
x=x+lowbit(x);
}
}
ll getsum(ll x)
{
ll ans=0;
while(x>0)
{
ans+=c[x];
x=x-lowbit(x);
}
return ans;
}
void dfs(ll x,ll m)//这种思路也可以
{
ll val=getsum(a[x+n]);
sum=sum-val;//这样操作保证了不重复和多算
ll cnt=v[x].size() ;
ll i,j;
for(i=0;i<cnt;i++)//遍历每个与x满足那个关系的点
{
j=v[x][i];
add(a[j],1);
dfs(j,m);
}
val=getsum(a[x+n]);
sum=sum+val;
}
int main()
{
cin>>t;
while(t--)
{
scanf("%lld%lld",&n,&k);
for(ll i=0;i<=n;i++)
v[i].clear() ;
ll i,j,val,l;
memset(vis,0,sizeof(vis));
memset(c,0,sizeof(c));
for(i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
if(j==0)a[i+n]=(1e18)+5;
else a[i+n] =k/a[i];
}
for(i=1;i<n;i++)
{
scanf("%lld%lld",&j,&val);
v[j].push_back(val);
vis[val]++;
}
//离散化
prepare(a);
for(i=1;i<=n;i++)
{
if(vis[i]==0)break;
}
sum=0;
dfs(i,0);
printf("%lld\n",sum);
}
return 0;
}