老实说比赛的时候没做出来,后来看了题解,然后发现有很多种办法可以做,一种解法是用Treap,我第一次看到这个数据结构,以后补发用Treap写的方法,先上用树状数组和线段树的做法。
其实树状数组和线段树的做法差不多,都是把这个问题转化到求区间【1,k/value】里的元素个数上面来,只不过需要离散化,因为这里的数据很大,显然不能开 pow(10,9)的数组。而且在离散化的时候不仅仅需要把每个节点自己的value离散掉,还需要把k/value一起离散掉,然后求出相对大小。然后还需要注意的一点就是特判0的情况,而特判0的时候就相当于把区间【1,n<<1】上的元素个数都求一遍。总之学到了一些处理的技巧,挺巧妙的。
至于跑DFS,我觉得应该都能想到吧。。问题就是在于求父节点和子节点的乘积小于k的问题上
先是树状数组的做法,树状数组毕竟还是要好写一些:
Tips:此处的val1用于存放每个节点的value,val2用于存放每个节点的k/value,arr是用来离散化的
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
struct Edge{
ll v,next;
Edge(ll uv=0,ll ne=0):v(uv),next(ne){};
}edge[100005];
ll t,n,k,arr[(100000<<1)+5],val1[100005],val2[100005],u,v,first[100005],bit[(100000<<1)+5],root,ans,sz=1;
bool flag[100005];
void addeage(),dfs(ll root),update(ll index,ll val);
ll query(ll index);
int main(){
ios_base::sync_with_stdio(false);
cin>>t;
while(t--){
cin>>n>>k;
for(int i=1;i<=n;++i){
cin>>val1[i];
val2[i]=(val1[i])?k/val1[i]:n<<1;
arr[(i<<1)-1]=val1[i];
arr[(i<<1)-2]=val2[i];
}
sort(arr,arr+(n<<1));
for(int i=1;i<=n-1;++i){
cin>>u>>v,addeage(),flag[v]=true;
val1[i]=lower_bound(arr,arr+(n<<1),val1[i])-arr+1;
val2[i]=lower_bound(arr,arr+(n<<1),val2[i])-arr+1;
}
val1[n]=lower_bound(arr,arr+(n<<1),val1[n])-arr+1;
val2[n]=lower_bound(arr,arr+(n<<1),val2[n])-arr+1;
for(int i=1;i<=n;++i)
if(!flag[i])root=i,i=n;
dfs(root);
cout<<ans<<endl;
memset(first+1,0,sizeof(ll)*n);
memset(flag+1,0,sizeof(bool)*n);
ans=0;sz=1;
}
return 0;
}
void addeage(){
edge[sz]=Edge(v,first[u]);
first[u]=sz++;
}
void dfs(ll root){
ans+=query(val2[root]);
update(val1[root],1);
for(ll i = first[root];i;i=edge[i].next)
dfs(edge[i].v);
update(val1[root],-1);
}
void update(ll index,ll val){
while(index<=(n<<1))
bit[index]+=val,index+=(index&-index);
}
ll query(ll index){
ll ans=0;
while(index)
ans+=bit[index],index-=(index&-index);
return ans;
}
接下来时线段树的写法,变量命名和上面一样,只是线段树在区间查询的时候把左端点省略掉了,因为左端点固定就是1嘛。
Tips:在相邻两组测试数据之间线段树不需要初始化
PS:果然线段树写起来要比树状数组麻烦一点。。
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll maxm = 1E+5;
ll n,k,val1[maxm+5],val2[maxm+5],u,v,arr[5+maxm<<1],st[5+maxm<<3],root,sz=1,ans,t,first[maxm+5];
struct Edge{
ll v,next;
Edge(ll uv=0,ll ne=0):v(uv),next(ne){};
}edge[maxm+5];
void addeage(),init(ll node=1,ll le=1,ll ri = n<<1),update(ll val,ll pos,ll node=1,ll le=1,ll ri =n<<1),dfs(ll root);
ll query(ll righ,ll node=1,ll le=1,ll ri = n<<1);
bool flag[maxm+5];
int main(){
ios_base::sync_with_stdio(false);
cin>>t;
while(t--){
cin>>n>>k;
//init();
for(int i=1;i<=n;++i){
cin>>val1[i];
val2[i]=(val1[i])?k/val1[i]:n<<1;
arr[(i<<1)-1]=val1[i];
arr[(i<<1)-2]=val2[i];
}
sort(arr,arr+(n<<1));
for(int i=1;i<=n-1;++i){
cin>>u>>v,addeage(),flag[v]=true;
val1[i]=lower_bound(arr,arr+(n<<1),val1[i])-arr+1;
val2[i]=lower_bound(arr,arr+(n<<1),val2[i])-arr+1;
}
val1[n]=lower_bound(arr,arr+(n<<1),val1[n])-arr+1;
val2[n]=lower_bound(arr,arr+(n<<1),val2[n])-arr+1;
for(int i=1;i<=n;++i)
if(!flag[i])root=i,i=n;
dfs(root);
cout<<ans<<endl;
memset(first+1,0,sizeof(ll)*n);
memset(flag+1,0,sizeof(bool)*n);
sz=1,ans=0;
}
return 0;
}
void init(ll node,ll le,ll ri){
if(le==ri)
st[le]=0;
else{
int k=(le+ri)>>1;
init(node<<1,le,k);
init((node<<1)+1,k+1,ri);
st[node]=0;
}
}
void update(ll val,ll pos,ll node,ll le,ll ri){
if(le==ri)
st[node]+=val;
else{
st[node]+=val;
int k=(le+ri)>>1;
if(pos<=k)
update(val,pos,node<<1,le,k);
else
update(val,pos,(node<<1)+1,k+1,ri);
}
}
void addeage(){
edge[sz]=Edge(v,first[u]);
first[u]=sz++;
}
void dfs(ll root){
ans+=query(val2[root]);
update(1,val1[root]);
for(ll i=first[root];i;i=edge[i].next)
dfs(edge[i].v);
update(-1,val1[root]);
}
ll query(ll righ,ll node,ll le,ll ri){
if(ri<=righ)return st[node];
ll k =(le+ri)>>1;
if(righ<=k)
return query(righ,node<<1,le,k);
else
return query(righ,node<<1,le,k)+query(righ,(node<<1)+1,k+1,ri);
}