题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=5678
题目大意:
给出一颗根为1的树,有n个点,每个点都有一个value值。然后给出n-1条边,表示有两个点之间是相连的,最后形成一颗树。然后给出m个询问,每次询问节点x以下所有节点形成的序列的中位数。
范围:
n<=10^5,m<=10^6,1<=val<=10^9。
思路:
主席树模板题。
对于节点x以下的节点形成的序列求中位数,我们很容易想到利用主席树求任意区间上的第k大来获得中位数。所以重点就是如何将节点x以下的节点构成一个序列放到区间里面。其实也很简单,因为他要的是所有子树,所以只需要一个dfs序处理出来就可以了。
注意要先预处理出n以内的所有节点的中位数,否则会超时。
代码:
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<iostream>
#include<set>
#include<map>
#define ll __int64
#define M 100005
#define mod 1000000007
using namespace std;
int n,m;
int val[M],head[M],tot,in[M],out[M],tt;
int u[M*10],v[M*10],a[M*10];
int T[M],lson[M*30],rson[M*30],c[M*30],cnt,kk;
ll ans[M];
int x[M*10];
ll MM[M*10];
void init_MM()
{
MM[0]=1;
for(int i=1;i<=1000002;i++)
MM[i]=MM[i-1]*10%mod;
}
struct edge
{
int to,next;
}edge[M*10];
void init_hash()
{
for(int i=1;i<=n;i++)
{
a[i]=val[i];
}
sort(val+1,val+1+n);
int siz=unique(val+1,val+1+n)-val-1;
kk=siz;
}
void add(int a,int b)
{
tot++;
edge[tot].to=b;
edge[tot].next=head[a];
head[a]=tot;
}
int update(int root,int pos ,int val)
{
int newroot=cnt++;
int tmp=newroot;
c[newroot]=c[root]+val;
int l=1,r=kk;
while(l<r)
{
int mid=l+r>>1;
if(mid>=pos)
{
lson[newroot]=cnt++;
rson[newroot]=rson[root];
root=lson[root];
newroot=lson[newroot];
r=mid;
}
else
{
lson[newroot]=lson[root];
rson[newroot]=cnt++;
root=rson[root];
newroot=rson[newroot];
l=mid+1;
}
c[newroot]=c[root]+val;
}
return tmp;
}
int hash1(int x)
{
return lower_bound(val+1,val+1+kk,x)-val;
}
void dfs(int x,int la) //dfs序处理
{
in[x]=++tt;
T[tt]=update(T[tt-1],hash1(a[x]),1);
for(int i=head[x];i!=-1;i=edge[i].next)
{
int yy = edge[i].to;
if(yy==la)continue;
dfs(yy,x);
}
out[x]=tt;
}
int build(int l,int r)
{
int root=cnt++;
c[root]=0;
int mid=l+r>>1;
if(l!=r)
{
lson[root]=build(l,mid);
rson[root]=build(mid+1,r);
}
return root;
}
int query(int l_root,int r_root,int k)
{
int l=1,r=kk;
while(l<r)
{
int mid=l+r>>1;
if(c[lson[r_root]]-c[lson[l_root]]>=k)
{
r=mid;
r_root=lson[r_root];
l_root=lson[l_root];
}
else
{
k-=(c[lson[r_root]]-c[lson[l_root]]);
l=mid+1;
r_root=rson[r_root];
l_root=rson[l_root];
}
}
return l;
}
int main()
{
int Ti,i,j,k,ff;
scanf("%d",&Ti);
init_MM();
while(Ti--)
{
ff=0;
tt=0;
tot=0;
memset(in,0,sizeof(in));
memset(out,0,sizeof(out));
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)
{
scanf("%d",&val[i]);
}
for(i=1;i<n;i++)
{
scanf("%d%d",&u[i],&v[i]);
add(u[i],v[i]);
add(v[i],u[i]);
}
init_hash();
cnt=0;
T[0]=build(1,kk);
dfs(1,-1);
ll sum=0;
for(i=1;i<=m;i++)
scanf("%d",&x[i]);
for(i=1;i<=n;i++)
{
ll temp;
int left,right;
left=in[i];
right=out[i];
if((left+right)%2==0)
{
temp=val[query(T[left-1],T[right],(left+right)/2-left+1)];
ans[i]=temp*10;
}
else
{
temp=val[query(T[left-1],T[right],(right+left)/2-left+1)]+val[query(T[left-1],T[right],(left+right)/2+1-left+1)];
ans[i]=temp*5;
}
}
for(i=1;i<m;i++)
{
ll temp;
sum=sum%mod+ans[x[i]]*MM[m-i-1]%mod;
sum=sum%mod;
}
ll temp;
sum=sum%mod+(ans[x[m]]/10)%mod;
sum=sum%mod;
if(ans[x[m]]%10)ff=1;
if(ff)
printf("%.1f\n",(double)sum+0.5);
else printf("%.1f\n",(double)sum);
}
}