Description
给出一棵
n
个节点的二叉树,
Input
第一行一整数
T
表示用例组数,每组用例首先输入一整数
Output
输出 n 个数表示每个点的最小花费
Sample Input
1
3
1 2 3
1 2
2 3
Sample Output
10 7 3
Solution
每个点的最小花费显然为最大权值
考虑启发式合并,把点权离散化后,用两个树状数组分别维护当前子树中出现了哪些权值的点以及出现点的权值,每次把点数小的子树插入到点数大的子树中,这些每次从树状数组中被插入和删去的点都是位于较小的子树中的,所以最多被操作 logn 次,总时间复杂度 O(nlog2n)
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
#define maxn 100005
int T,n,m,a[maxn],val[maxn],id[maxn],ls[maxn],rs[maxn],sz[maxn];
ll Ans,ans[maxn];
vector<int>g[maxn];
struct BIT
{
#define lowbit(x) (x&(-x))
ll b[maxn];
void init()
{
memset(b,0,sizeof(b));
}
void update(int x,int v)
{
while(x<=m)
{
b[x]+=v;
x+=lowbit(x);
}
}
ll query(int x)
{
ll ans=0;
while(x)
{
ans+=b[x];
x-=lowbit(x);
}
return ans;
}
}num,sum;
void dfs(int u,int fa)
{
sz[u]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
if(!ls[u])ls[u]=v;
else rs[u]=v;
dfs(v,u);
sz[u]+=sz[v];
}
if(ls[u]&&rs[u]&&sz[ls[u]]>sz[rs[u]])swap(ls[u],rs[u]);
}
void add(int x)
{
Ans+=((num.query(m)-num.query(x)+1)*val[x]+sum.query(x));
num.update(x,1),sum.update(x,val[x]);
}
void del(int x)
{
num.update(x,-1),sum.update(x,-val[x]);
Ans-=((num.query(m)-num.query(x)+1)*val[x]+sum.query(x));
}
void Add(int u)
{
add(id[u]);
if(ls[u])Add(ls[u]);
if(rs[u])Add(rs[u]);
}
void Del(int u)
{
del(id[u]);
if(ls[u])Del(ls[u]);
if(rs[u])Del(rs[u]);
}
void Solve(int u)
{
if(!ls[u])
{
ans[u]=a[u];
add(id[u]);
return ;
}
Solve(ls[u]);
if(rs[u])
{
Del(ls[u]);
Solve(rs[u]);
Add(ls[u]);
}
add(id[u]);
ans[u]=Ans;
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
Ans=0;
num.init(),sum.init();
for(int i=1;i<=n;i++)
{
g[i].clear();
ls[i]=rs[i]=0;
}
for(int i=1;i<=n;i++)scanf("%d",&a[i]),val[i]=a[i];
sort(val+1,val+n+1);
m=unique(val+1,val+n+1)-val-1;
for(int i=1;i<=n;i++)id[i]=lower_bound(val+1,val+m+1,a[i])-val;
//for(int i=1;i<=n;i++)printf("id=%d val=%d\n",id[i],val[id[i]]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v),g[v].push_back(u);
}
dfs(1,0);
Solve(1);
for(int i=1;i<=n;i++)printf("%I64d ",ans[i]);
printf("\n");
}
return 0;
}