目录
题意:
给你一颗树,计算有多少个(u,v){u是v的祖先},满足a[u]*a[v]<=k。
数据范围:
1≤N≤10^5
0≤ai≤10^9
0≤k≤10^18
思路:
根据范围,只能遍历一遍,dfs深搜遍历。dfs深搜有个特性,对于搜到的一个点v,此时经过的点都是他的祖先(同层次的回溯之后就已经不算了,已经从这个树枝出来了)。于是边搜索边记忆所有的结点值,每到新节点v时就暴力遍历看看那些满足。显然超时,根据这个思路,想想如何优化找到多少个之前满足<=k/a[v]的结点值,也就是查找小于等于k/a[v]的个数。于是联想到了----树状数组!(单点修改+1,区间求和个数)
但是元素范围到10^9,数组开不了------离散化(排序+去重)
离散化,可以理解为把数值范围缩小,比如a[1]=1,a[2]=10000。这时候1对应1,2对应10000,把范围从10000缩小到了2。从数据范围缩小到了数据数量。因为我们要查找<=k/a[v]的,所以离散化所有的k/a[v]。注意a[v]=0时,所有的都可以满足,更新为INF。
for(int i = 1;i <= n;i++){
vec[i].clear();
cin >> a[i];
if(!a[i]) s[p++]=INF;
else s[p++] = k/a[i];
}
sort(s,s+p);//排序
p = unique(s,s+p)-s;//去重
这样就巧妙的用个数来替代元素值。当经过了结点u时,就将a[u]放入到树状数组中更新,因为值不是连续的,所以可能不存在a[u],这时候寻找第一个大于等于a[u]的进行+1操作即可。
原因:对于一个结点k/a[i]来说,所有<=k/a[i]的都满足条件(区间求和的原因),所以这时候把a[u]向上取结点即可,如果向下的话,会造成k/a[i]的少一种情况。
还有一种情况就是向下取整,k/a[i]=x.xxxx 这种情况直接取整即可。
原因:所有值都为整数,因为<=k/a[i]的都满足,所有向下取整不会有影响,向上取整的话就造成k/a[i]+1多一个。
代码:
#include<algorithm>
#include<iostream>
#include<cstring>
#include<string>
#include<cstdlib>
#include<map>
#include<cmath>
#include<vector>
#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn = 1e5+50;
const int INF = 0x3f3f3f3f;
int tree[maxn];
int n,p;
ll k;
int a[maxn];
int s[maxn];
bool vis[maxn];
ll ans;
vector<int>vec[maxn];
int lowbit(int x){
return x&(-x);
}
void update(int x,int val){
for(;x<= n;x+=lowbit(x)) tree[x]+=val;
}
int getsum(int x){
int ans = 0;
for(;x;x-=lowbit(x)) ans+=tree[x];
return ans;
}
void dfs(int u){
ans += getsum(lower_bound(s,s+p,a[u]?k/a[u]:INF)-s+1);
update(lower_bound(s,s+p,a[u])-s+1,1);
for(int i = 0;i < vec[u].size();i++){
int v = vec[u][i];
dfs(v);
}
update(lower_bound(s,s+p,a[u])-s+1,-1);
}
int main(){
int t;
cin >> t;
while(t--){
cin >> n >> k;
p=0;ans=0;
memset(tree,0,sizeof(tree));
memset(vis,false,sizeof(vis));
for(int i = 1;i <= n;i++){
vec[i].clear();
cin >> a[i];
if(!a[i]) s[p++]=INF;
else s[p++] = k/a[i];
}
sort(s,s+p);
p = unique(s,s+p)-s;
int u,v;
for(int i = 0;i < n-1;i++){
cin >> u >> v;
vis[v] = true;
vec[u].push_back(v);
}
for(int i = 1;i <= n;i++){
if(!vis[i]) dfs(i);
}
cout<<ans<<endl;
}
return 0;
}
大佬代码,清晰易懂
https://www.cnblogs.com/ctyakwf/p/12285090.html
更亲近一些,离散化a[i]数组,对于每一个经过结点u,进入下一个结点时,加入该结点进树状数组。每次加入必定可以精确到根节点。
再到每一个结点v时,查找所有小于等于k/a[v]的数,注意一点的时,如果k/a[v]没有精确到根节点,值在两个结点代表值之间的话,因为所有<=k/a[v]的点都可以,所以需要向下取结点,如果相等就相等。向上的话可以用lower_bound,这个可以用upper_bound找到第一个大于k/a[v]的后再减一。即:
int x=upper_bound(num.begin(),num.end(),k/a[u])-num.begin();
对于k/a[v]向下取整,与上面一个道理。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<map>
#include<set>
#include<vector>
using namespace std;
typedef long long ll;
const int N=5e5+10;
const ll inf=0x3f3f3f3f;
int h[N],ne[N],e[N],in[N];
int idx;
int a[N];
int b[N];
ll n;
int tr[N];
int size;
ll k;
ll ans;
vector<int> num;
void add(int a,int b){
//链式前向星,vector也可以
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int lowbit(int x){
return x&-x;
}
void modify(int x,int c){
int i;
for(i=x;i<=n;i+=lowbit(i)){
tr[i]+=c;
}
}
ll sum(int x){
int i;
ll res=0;
for(i=x;i;i-=lowbit(i)){
res+=tr[i];
}
return res;
}
void dfs(ll u){
int i;
int x=upper_bound(num.begin(),num.end(),a[u]?k/a[u]:inf)-num.begin();
//这里修改了一下a[u]=的情况。
ans+=sum(x);
int y=lower_bound(num.begin(),num.end(),a[u])-num.begin()+1;
modify(y,1);
for(i=h[u];i!=-1;i=ne[i]){
int j=e[i];
dfs(j);
}
modify(y,-1);
}
int main(){
int t;
cin>>t;
while(t--){
idx=0;
ans=0;
num.clear();
scanf("%d%lld",&n,&k);
memset(h,-1,sizeof h);
memset(in,0,sizeof in);
memset(tr,0,sizeof tr);
int i;
for(i=1;i<=n;i++){
scanf("%d",&a[i]);
num.push_back(a[i]);
}
sort(num.begin(),num.end());
num.erase(unique(num.begin(),num.end()),num.end());
for(i=1;i<=n-1;i++){
ll b,c;
scanf("%lld%lld",&c,&b);
add(c,b);
in[b]++;
}
for(i=1;i<=n;i++){
if(!in[i]){
dfs(i);
break;
}
}
cout<<ans<<endl;
}
return 0;
}