题目题意还是很清楚,要统计二元组(u,v)的对数,要同时满足两个条件:①u是v的祖先,②u*v<=k。那么我可以在dfs的时候将离散化后的每个节点的值用树状数组更新,
然后每到一个节点就用树状数组统计在祖先中有多少个小于k/a[now],记得在回溯的时候要消除兄弟的影响。
用map实现离散化版本:
#include<cstdio>
#include<cstring>
#include<vector>
#include<map>
#include<algorithm>
using namespace std;
#define LL long long
#define lowbit(x) x&(-x)
const int maxn=100010;
int t,n,u,v,a[maxn],c[maxn],q[maxn],rd[maxn];
vector<int> G[maxn];
map<LL,int> mp;
LL k,b[maxn];
void update(int i,int x)
{
while(i<=n)
{
q[i]+=x;
i+=lowbit(i);
}
}
LL getsum(int i)
{
LL res=0;
while(i)
{
res+=q[i];
i-=lowbit(i);
}
return res;
}
int find(LL v)
{
int l=1,r=n,mid;
LL pos;
while(l<=r)
{
mid=(l+r)>>1;
pos=b[mid];
if(pos==v) return mid;
else if(pos>v) r=mid-1;
else l=mid+1;
}
if(b[l]>v) return l-1;
else return l;
}
LL dfs(int now)
{
LL ans=0,x=find(k/a[now]);
ans+=getsum(x);
update(c[now],1);
int len=G[now].size();
for(int i=0;i<len;i++) ans+=dfs(G[now][i]);
update(c[now],-1);
return ans;
}
int main()
{
scanf("%d",&t);
while(t--)
{
memset(q,0,sizeof q);
mp.clear();
scanf("%d%lld",&n,&k);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
G[i].clear();
rd[i]=0;
b[i]=(LL)a[i];
}
sort(b+1,b+n+1);
for(int i=1;i<=n;i++) mp[b[i]]=i;
b[n+1]=4e18;
for(int i=1;i<=n;i++) c[i]=mp[a[i]];
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
G[u].push_back(v);
rd[v]++;
}
int rt=-1;
for(int i=1;i<=n;i++) if(!rd[i]) {rt=i;break;}
printf("%lld\n",dfs(rt));
}
return 0;
}
自己写的离散化(竟然比上面的跑得慢=.=)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define LL long long
#define lowbit(x) x&(-x)
const int maxn=100010;
int t,n,u,v,a[maxn],c[maxn],rd[maxn];
LL k,b[maxn];
vector<int> G[maxn];
void update(int i,int x)
{
while(i<=n)
{
c[i]+=x;
i+=lowbit(i);
}
}
LL getsum(int i)
{
LL res=0;
while(i)
{
res+=c[i];
i-=lowbit(i);
}
return res;
}
LL dfs(int now)
{
LL ans=0;
int len=G[now].size();
LL tmp1=upper_bound(b+1,b+n+1,(LL)k/a[now])-b;
LL tmp2=lower_bound(b+1,b+n+1,a[now])-b;
ans+=getsum(tmp1-1);
if(tmp2)update(tmp2,1);
for(int i=0;i<len;i++) ans+=dfs(G[now][i]);
if(tmp2)update(tmp2,-1);
return ans;
}
int main()
{
scanf("%d",&t);
while(t--)
{
memset(c,0,sizeof c);
scanf("%d%lld",&n,&k);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
b[i]=(LL)a[i];
rd[i]=0;
G[i].clear();
}
sort(b+1,b+n+1);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
G[u].push_back(v);
rd[v]++;
}
int rt=-1;
for(int i=1;i<=n;i++) if(!rd[i]){rt=i;break;}
printf("%lld\n",dfs(rt));
}
return 0;
}